diff --git a/.example/database/gdb/mysql/gdb_insert.go b/.example/database/gdb/mysql/gdb_insert.go index ee345b370..a6afa8491 100644 --- a/.example/database/gdb/mysql/gdb_insert.go +++ b/.example/database/gdb/mysql/gdb_insert.go @@ -10,9 +10,9 @@ func main() { //db := g.DB() gdb.AddDefaultConfigNode(gdb.ConfigNode{ - LinkInfo: "root:12345678@tcp(127.0.0.1:3306)/test?parseTime=true&loc=Local", - Type: "mysql", - Charset: "utf8", + Link: "root:12345678@tcp(127.0.0.1:3306)/test?parseTime=true&loc=Local", + Type: "mysql", + Charset: "utf8", }) db, _ := gdb.New() diff --git a/container/garray/garray_normal_any.go b/container/garray/garray_normal_any.go index 985c02fc7..a3b450b54 100644 --- a/container/garray/garray_normal_any.go +++ b/container/garray/garray_normal_any.go @@ -8,8 +8,8 @@ package garray import ( "bytes" - "errors" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/empty" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/text/gstr" @@ -123,7 +123,7 @@ func (a *Array) Set(index int, value interface{}) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } a.array[index] = value return nil @@ -176,7 +176,7 @@ func (a *Array) InsertBefore(index int, value interface{}) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } rear := append([]interface{}{}, a.array[index:]...) a.array = append(a.array[0:index], value) @@ -189,7 +189,7 @@ func (a *Array) InsertAfter(index int, value interface{}) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } rear := append([]interface{}{}, a.array[index+1:]...) 
a.array = append(a.array[0:index+1], value) @@ -545,7 +545,7 @@ func (a *Array) Fill(startIndex int, num int, value interface{}) error { a.mu.Lock() defer a.mu.Unlock() if startIndex < 0 || startIndex > len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", startIndex, len(a.array))) + return gerror.Newf("index %d out of array range %d", startIndex, len(a.array)) } for i := startIndex; i < startIndex+num; i++ { if i > len(a.array)-1 { diff --git a/container/garray/garray_normal_int.go b/container/garray/garray_normal_int.go index 05a10f0fe..3e43a56db 100644 --- a/container/garray/garray_normal_int.go +++ b/container/garray/garray_normal_int.go @@ -8,8 +8,8 @@ package garray import ( "bytes" - "errors" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/json" "math" "sort" @@ -104,7 +104,7 @@ func (a *IntArray) Set(index int, value int) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } a.array[index] = value return nil @@ -175,7 +175,7 @@ func (a *IntArray) InsertBefore(index int, value int) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } rear := append([]int{}, a.array[index:]...) a.array = append(a.array[0:index], value) @@ -188,7 +188,7 @@ func (a *IntArray) InsertAfter(index int, value int) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } rear := append([]int{}, a.array[index+1:]...) a.array = append(a.array[0:index+1], value) @@ -559,7 +559,7 @@ func (a *IntArray) Fill(startIndex int, num int, value int) error { a.mu.Lock() defer a.mu.Unlock() if startIndex < 0 || startIndex > len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", startIndex, len(a.array))) + return gerror.Newf("index %d out of array range %d", startIndex, len(a.array)) } for i := startIndex; i < startIndex+num; i++ { if i > len(a.array)-1 { diff --git a/container/garray/garray_normal_str.go b/container/garray/garray_normal_str.go index f1754e80d..832a70cb6 100644 --- a/container/garray/garray_normal_str.go +++ b/container/garray/garray_normal_str.go @@ -8,8 +8,7 @@ package garray import ( "bytes" - "errors" - "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/text/gstr" "math" @@ -91,7 +90,7 @@ func (a *StrArray) Set(index int, value string) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } a.array[index] = value return nil @@ -163,7 +162,7 @@ func (a *StrArray) InsertBefore(index int, value string) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } rear := append([]string{}, a.array[index:]...) 
a.array = append(a.array[0:index], value) @@ -176,7 +175,7 @@ func (a *StrArray) InsertAfter(index int, value string) error { a.mu.Lock() defer a.mu.Unlock() if index < 0 || index >= len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", index, len(a.array))) + return gerror.Newf("index %d out of array range %d", index, len(a.array)) } rear := append([]string{}, a.array[index+1:]...) a.array = append(a.array[0:index+1], value) @@ -563,7 +562,7 @@ func (a *StrArray) Fill(startIndex int, num int, value string) error { a.mu.Lock() defer a.mu.Unlock() if startIndex < 0 || startIndex > len(a.array) { - return errors.New(fmt.Sprintf("index %d out of array range %d", startIndex, len(a.array))) + return gerror.Newf("index %d out of array range %d", startIndex, len(a.array)) } for i := startIndex; i < startIndex+num; i++ { if i > len(a.array)-1 { diff --git a/container/gpool/gpool.go b/container/gpool/gpool.go index 0ac919139..e4125aaac 100644 --- a/container/gpool/gpool.go +++ b/container/gpool/gpool.go @@ -8,7 +8,7 @@ package gpool import ( - "errors" + "github.com/gogf/gf/errors/gerror" "time" "github.com/gogf/gf/container/glist" @@ -66,7 +66,7 @@ func New(ttl time.Duration, newFunc NewFunc, expireFunc ...ExpireFunc) *Pool { // Put puts an item to pool. func (p *Pool) Put(value interface{}) error { if p.closed.Val() { - return errors.New("pool is closed") + return gerror.New("pool is closed") } item := &poolItem{ value: value, @@ -117,7 +117,7 @@ func (p *Pool) Get() (interface{}, error) { if p.NewFunc != nil { return p.NewFunc() } - return nil, errors.New("pool is empty") + return nil, gerror.New("pool is empty") } // Size returns the count of available items of pool. diff --git a/container/gset/gset_str_set.go b/container/gset/gset_str_set.go index 68b4ec6bc..24323c3e7 100644 --- a/container/gset/gset_str_set.go +++ b/container/gset/gset_str_set.go @@ -21,7 +21,7 @@ type StrSet struct { data map[string]struct{} } -// New create and returns a new set, which contains un-repeated items. +// NewStrSet create and returns a new set, which contains un-repeated items. // The parameter is used to specify whether using set in concurrent-safety, // which is false in default. 
func NewStrSet(safe ...bool) *StrSet { diff --git a/crypto/gaes/gaes.go b/crypto/gaes/gaes.go index 6e2f84d5a..bdb04d78a 100644 --- a/crypto/gaes/gaes.go +++ b/crypto/gaes/gaes.go @@ -11,7 +11,7 @@ import ( "bytes" "crypto/aes" "crypto/cipher" - "errors" + "github.com/gogf/gf/errors/gerror" ) var ( @@ -63,7 +63,7 @@ func DecryptCBC(cipherText []byte, key []byte, iv ...[]byte) ([]byte, error) { } blockSize := block.BlockSize() if len(cipherText) < blockSize { - return nil, errors.New("cipherText too short") + return nil, gerror.New("cipherText too short") } ivValue := ([]byte)(nil) if len(iv) > 0 { @@ -72,7 +72,7 @@ func DecryptCBC(cipherText []byte, key []byte, iv ...[]byte) ([]byte, error) { ivValue = []byte(IVDefaultValue) } if len(cipherText)%blockSize != 0 { - return nil, errors.New("cipherText is not a multiple of the block size") + return nil, gerror.New("cipherText is not a multiple of the block size") } blockModel := cipher.NewCBCDecrypter(block, ivValue) plainText := make([]byte, len(cipherText)) @@ -93,22 +93,22 @@ func PKCS5Padding(src []byte, blockSize int) []byte { func PKCS5UnPadding(src []byte, blockSize int) ([]byte, error) { length := len(src) if blockSize <= 0 { - return nil, errors.New("invalid blocklen") + return nil, gerror.New("invalid blocklen") } if length%blockSize != 0 || length == 0 { - return nil, errors.New("invalid data len") + return nil, gerror.New("invalid data len") } unpadding := int(src[length-1]) if unpadding > blockSize || unpadding == 0 { - return nil, errors.New("invalid padding") + return nil, gerror.New("invalid padding") } padding := src[length-unpadding:] for i := 0; i < unpadding; i++ { if padding[i] != byte(unpadding) { - return nil, errors.New("invalid padding") + return nil, gerror.New("invalid padding") } } @@ -146,7 +146,7 @@ func DecryptCFB(cipherText []byte, key []byte, unPadding int, iv ...[]byte) ([]b return nil, err } if len(cipherText) < aes.BlockSize { - return nil, errors.New("cipherText too short") + return nil, gerror.New("cipherText too short") } ivValue := ([]byte)(nil) if len(iv) > 0 { diff --git a/crypto/gdes/gdes.go b/crypto/gdes/gdes.go index 55d936dfe..c60f7c190 100644 --- a/crypto/gdes/gdes.go +++ b/crypto/gdes/gdes.go @@ -11,7 +11,7 @@ import ( "bytes" "crypto/cipher" "crypto/des" - "errors" + "github.com/gogf/gf/errors/gerror" ) const ( @@ -66,7 +66,7 @@ func DecryptECB(cipherText []byte, key []byte, padding int) ([]byte, error) { // The length of the should be either 16 or 24 bytes. func EncryptECBTriple(plainText []byte, key []byte, padding int) ([]byte, error) { if len(key) != 16 && len(key) != 24 { - return nil, errors.New("key length error") + return nil, gerror.New("key length error") } text, err := Padding(plainText, padding) @@ -100,7 +100,7 @@ func EncryptECBTriple(plainText []byte, key []byte, padding int) ([]byte, error) // The length of the should be either 16 or 24 bytes. 
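The changes above migrate error construction in garray, gpool and gaes from the standard library (`errors.New(fmt.Sprintf(...))`) to the framework's `gerror` package. A minimal sketch of the pattern, using a made-up bounds check in place of the real call sites:

```go
package main

import (
	"fmt"

	"github.com/gogf/gf/errors/gerror"
)

// checkIndex is an illustrative stand-in for the bounds checks in garray.
// The old code built the error with errors.New(fmt.Sprintf(...)); the new
// code uses gerror.Newf, which formats the message and records a stack trace.
func checkIndex(index, length int) error {
	if index < 0 || index >= length {
		return gerror.Newf("index %d out of array range %d", index, length)
	}
	return nil
}

func main() {
	err := checkIndex(10, 3)
	fmt.Println(err)               // index 10 out of array range 3
	fmt.Println(gerror.Stack(err)) // the gerror error also carries a stack trace
}
```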
func DecryptECBTriple(cipherText []byte, key []byte, padding int) ([]byte, error) { if len(key) != 16 && len(key) != 24 { - return nil, errors.New("key length error") + return nil, gerror.New("key length error") } var newKey []byte @@ -138,7 +138,7 @@ func EncryptCBC(plainText []byte, key []byte, iv []byte, padding int) ([]byte, e } if len(iv) != block.BlockSize() { - return nil, errors.New("iv length invalid") + return nil, gerror.New("iv length invalid") } text, err := Padding(plainText, padding) @@ -161,7 +161,7 @@ func DecryptCBC(cipherText []byte, key []byte, iv []byte, padding int) ([]byte, } if len(iv) != block.BlockSize() { - return nil, errors.New("iv length invalid") + return nil, gerror.New("iv length invalid") } text := make([]byte, len(cipherText)) @@ -179,7 +179,7 @@ func DecryptCBC(cipherText []byte, key []byte, iv []byte, padding int) ([]byte, // EncryptCBCTriple encrypts using TripleDES and CBC mode. func EncryptCBCTriple(plainText []byte, key []byte, iv []byte, padding int) ([]byte, error) { if len(key) != 16 && len(key) != 24 { - return nil, errors.New("key length invalid") + return nil, gerror.New("key length invalid") } var newKey []byte @@ -196,7 +196,7 @@ func EncryptCBCTriple(plainText []byte, key []byte, iv []byte, padding int) ([]b } if len(iv) != block.BlockSize() { - return nil, errors.New("iv length invalid") + return nil, gerror.New("iv length invalid") } text, err := Padding(plainText, padding) @@ -214,7 +214,7 @@ func EncryptCBCTriple(plainText []byte, key []byte, iv []byte, padding int) ([]b // DecryptCBCTriple decrypts <cipherText> using TripleDES and CBC mode. func DecryptCBCTriple(cipherText []byte, key []byte, iv []byte, padding int) ([]byte, error) { if len(key) != 16 && len(key) != 24 { - return nil, errors.New("key length invalid") + return nil, gerror.New("key length invalid") } var newKey []byte @@ -231,7 +231,7 @@ func DecryptCBCTriple(cipherText []byte, key []byte, iv []byte, padding int) ([] } if len(iv) != block.BlockSize() { - return nil, errors.New("iv length invalid") + return nil, gerror.New("iv length invalid") } text := make([]byte, len(cipherText)) @@ -262,12 +262,12 @@ func Padding(text []byte, padding int) ([]byte, error) { switch padding { case NOPADDING: if len(text)%8 != 0 { - return nil, errors.New("text length invalid") + return nil, gerror.New("text length invalid") } case PKCS5PADDING: return PaddingPKCS5(text, 8), nil default: - return nil, errors.New("padding type error") + return nil, gerror.New("padding type error") } return text, nil @@ -277,12 +277,12 @@ func UnPadding(text []byte, padding int) ([]byte, error) { switch padding { case NOPADDING: if len(text)%8 != 0 { - return nil, errors.New("text length invalid") + return nil, gerror.New("text length invalid") } case PKCS5PADDING: return UnPaddingPKCS5(text), nil default: - return nil, errors.New("padding type error") + return nil, gerror.New("padding type error") } return text, nil } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index ab415021a..f197e5675 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -38,6 +38,7 @@ type DB interface { // relational databases but also for NoSQL databases in the future. The name // "Table" is not proper for that purpose any more. // Also see Core.Table. + // Deprecated. Table(tableNameOrStruct ...interface{}) *Model // Model creates and returns a new ORM model from given schema. @@ -51,6 +52,9 @@ type DB interface { // Also see Core.Model. 
Model(tableNameOrStruct ...interface{}) *Model + // Raw creates and returns a model based on a raw sql not a table. + Raw(rawSql string, args ...interface{}) *Model + // Schema creates and returns a schema. // Also see Core.Schema. Schema(schema string) *Schema @@ -66,8 +70,6 @@ type DB interface { // Ctx is a chaining function, which creates and returns a new DB that is a shallow copy // of current DB object and with given context in it. - // Note that this returned DB object can be used only once, so do not assign it to - // a global or package variable for long using. // Also see Core.Ctx. Ctx(ctx context.Context) DB @@ -83,16 +85,11 @@ type DB interface { // Common APIs for CURD. // =========================================================================== - Insert(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.Insert. - InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.InsertIgnore. - InsertAndGetId(table string, data interface{}, batch ...int) (int64, error) // See Core.InsertAndGetId. - Replace(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.Replace. - Save(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.Save. - - BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) // See Core.BatchInsert. - BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) // See Core.BatchReplace. - BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) // See Core.BatchSave. - + Insert(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.Insert. + InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.InsertIgnore. + InsertAndGetId(table string, data interface{}, batch ...int) (int64, error) // See Core.InsertAndGetId. + Replace(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.Replace. + Save(table string, data interface{}, batch ...int) (sql.Result, error) // See Core.Save. Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) // See Core.Update. Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) // See Core.Delete. @@ -100,27 +97,27 @@ type DB interface { // Internal APIs for CURD, which can be overwrote for custom CURD implements. // =========================================================================== - DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery. - DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec. - DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare. DoGetAll(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) // See Core.DoGetAll. - DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) // See Core.DoInsert. - DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) // See Core.DoBatchInsert. + DoInsert(ctx context.Context, link Link, table string, data List, option DoInsertOption) (result sql.Result, err error) // See Core.DoInsert. 
DoUpdate(ctx context.Context, link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoUpdate. DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete. + DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery. + DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec. + DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoCommit. + DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare. // =========================================================================== // Query APIs for convenience purpose. // =========================================================================== - GetAll(sql string, args ...interface{}) (Result, error) // See Core.GetAll. - GetOne(sql string, args ...interface{}) (Record, error) // See Core.GetOne. - GetValue(sql string, args ...interface{}) (Value, error) // See Core.GetValue. - GetArray(sql string, args ...interface{}) ([]Value, error) // See Core.GetArray. - GetCount(sql string, args ...interface{}) (int, error) // See Core.GetCount. - GetStruct(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetStruct. - GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error // See Core.GetStructs. - GetScan(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetScan. + GetAll(sql string, args ...interface{}) (Result, error) // See Core.GetAll. + GetOne(sql string, args ...interface{}) (Record, error) // See Core.GetOne. + GetValue(sql string, args ...interface{}) (Value, error) // See Core.GetValue. + GetArray(sql string, args ...interface{}) ([]Value, error) // See Core.GetArray. + GetCount(sql string, args ...interface{}) (int, error) // See Core.GetCount. + GetScan(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetScan. + Union(unions ...*Model) *Model // See Core.Union. + UnionAll(unions ...*Model) *Model // See Core.UnionAll. // =========================================================================== // Master/Slave specification support. @@ -172,14 +169,7 @@ type DB interface { GetChars() (charLeft string, charRight string) // See Core.GetChars. Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables. TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields. - FilteredLinkInfo() string // See Core.FilteredLinkInfo. - - // HandleSqlBeforeCommit is a hook function, which deals with the sql string before - // it's committed to underlying driver. The parameter `link` specifies the current - // database connection operation object. You can modify the sql string `sql` and its - // arguments `args` as you wish before they're committed to driver. - // Also see Core.HandleSqlBeforeCommit. - HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) + FilteredLink() string // FilteredLink is used for filtering sensitive information in `Link` configuration before output it to tracing server. } // Core is the base struct for database management. 
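Beyond the interface reshuffle, the DB interface above gains `Raw`, `Union` and `UnionAll` while dropping the standalone Batch* methods. A rough usage sketch of the new entry points, assuming the usual chain methods (`All`, `Where`) on the returned model; the table and column names are made up:

```go
package main

import (
	"fmt"

	"github.com/gogf/gf/frame/g"
)

func main() {
	db := g.DB()

	// Raw builds a model from a raw sql statement instead of a table name.
	users, err := db.Raw("SELECT id,name FROM user WHERE id>?", 100).All()
	if err != nil {
		panic(err)
	}
	fmt.Println(users)

	// Union / UnionAll compose "(SELECT ...) UNION (SELECT ...)" statements
	// from existing models; internally they end up in Core.doUnion below.
	r, err := db.Union(
		db.Model("user").Where("id", 1),
		db.Model("user").Where("id", 2),
	).All()
	if err != nil {
		panic(err)
	}
	fmt.Println(r)
}
```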
@@ -224,6 +214,14 @@ type Sql struct { IsTransaction bool // IsTransaction marks whether this sql is executed in transaction. } +// DoInsertOption is the input struct for function DoInsert. +type DoInsertOption struct { + OnDuplicateStr string + OnDuplicateMap map[string]interface{} + InsertOption int // Insert operation. + BatchCount int // Batch count for batch inserting. +} + // TableField is the struct for table field. type TableField struct { Index int // For ordering purpose as map is unordered. @@ -252,6 +250,10 @@ type ( ) const ( + queryTypeNormal = 0 + queryTypeCount = 1 + unionTypeNormal = 0 + unionTypeAll = 1 insertOptionDefault = 0 insertOptionReplace = 1 insertOptionSave = 2 @@ -263,6 +265,9 @@ const ( ctxTimeoutTypeExec = iota ctxTimeoutTypeQuery ctxTimeoutTypePrepare + commandEnvKeyForDryRun = "gf.gdb.dryrun" + ctxStrictKeyName = "gf.gdb.CtxStrictEnabled" + ctxStrictErrorStr = "context is required for database operation, did you missing call function Ctx" ) var ( @@ -307,7 +312,7 @@ var ( func init() { // allDryRun is initialized from environment or command options. - allDryRun = gcmd.GetOptWithEnv("gf.gdb.dryrun", false).Bool() + allDryRun = gcmd.GetOptWithEnv(commandEnvKeyForDryRun, false).Bool() } // Register registers custom database driver to gdb. @@ -479,14 +484,16 @@ func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error // Cache the underlying connection pool object by node. v, _ := internalCache.GetOrSetFuncLock(node.String(), func() (interface{}, error) { intlog.Printf( + c.db.GetCtx(), `open new connection, master:%#v, config:%#v, node:%#v`, master, c.config, node, ) defer func() { if err != nil { - intlog.Printf(`open new connection failed: %v, %#v`, err, node) + intlog.Printf(c.db.GetCtx(), `open new connection failed: %v, %#v`, err, node) } else { intlog.Printf( + c.db.GetCtx(), `open new connection success, master:%#v, config:%#v, node:%#v`, master, c.config, node, ) diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 84d50bbd9..ab11e9b66 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -40,6 +40,7 @@ func (c *Core) Ctx(ctx context.Context) DB { if c.ctx != nil { return c.db } + ctx = context.WithValue(ctx, ctxStrictKeyName, 1) // It makes a shallow copy of current db and changes its context for next chaining operation. var ( err error @@ -189,9 +190,9 @@ func (c *Core) GetScan(pointer interface{}, sql string, args ...interface{}) err k = t.Elem().Kind() switch k { case reflect.Array, reflect.Slice: - return c.db.GetStructs(pointer, sql, args...) + return c.db.GetCore().GetStructs(pointer, sql, args...) case reflect.Struct: - return c.db.GetStruct(pointer, sql, args...) + return c.db.GetCore().GetStruct(pointer, sql, args...) } return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k) } @@ -224,6 +225,39 @@ func (c *Core) GetCount(sql string, args ...interface{}) (int, error) { return value.Int(), nil } +// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement. +func (c *Core) Union(unions ...*Model) *Model { + return c.doUnion(unionTypeNormal, unions...) +} + +// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement. +func (c *Core) UnionAll(unions ...*Model) *Model { + return c.doUnion(unionTypeAll, unions...) 
+} + +func (c *Core) doUnion(unionType int, unions ...*Model) *Model { + var ( + unionTypeStr string + composedSqlStr string + composedArgs = make([]interface{}, 0) + ) + if unionType == unionTypeAll { + unionTypeStr = "UNION ALL" + } else { + unionTypeStr = "UNION" + } + for _, v := range unions { + sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(queryTypeNormal, false) + if composedSqlStr == "" { + composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder) + } else { + composedSqlStr += fmt.Sprintf(` %s (%s)`, unionTypeStr, sqlWithHolder) + } + composedArgs = append(composedArgs, holderArgs...) + } + return c.db.Raw(composedSqlStr, composedArgs...) +} + // PingMaster pings the master node to check authentication or keeps the connection alive. func (c *Core) PingMaster() error { if master, err := c.db.Master(); err != nil { @@ -331,181 +365,15 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e // 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one; // 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one; // 3: ignore: if there's unique/primary key in the data, it ignores the inserting; -func (c *Core) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { - table = c.QuotePrefixTableName(table) +func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) { var ( - fields []string - values []string - params []interface{} - dataMap Map - reflectValue = reflect.ValueOf(data) - reflectKind = reflectValue.Kind() + keys []string // Field names. + values []string // Value holder string array, like: (?,?,?) + params []interface{} // Values that will be committed to underlying database driver. + onDuplicateStr string // onDuplicateStr is used in "ON DUPLICATE KEY UPDATE" statement. ) - if reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { - case reflect.Slice, reflect.Array: - return c.db.DoBatchInsert(ctx, link, table, data, option, batch...) - case reflect.Struct: - if _, ok := data.(apiInterfaces); ok { - return c.db.DoBatchInsert(ctx, link, table, data, option, batch...) - } else { - dataMap = ConvertDataForTableRecord(data) - } - case reflect.Map: - dataMap = ConvertDataForTableRecord(data) - default: - return result, gerror.New(fmt.Sprint("unsupported data type:", reflectKind)) - } - if len(dataMap) == 0 { - return nil, gerror.New("data cannot be empty") - } - var ( - charL, charR = c.db.GetChars() - operation = GetInsertOperationByOption(option) - updateStr = "" - ) - for k, v := range dataMap { - fields = append(fields, charL+k+charR) - if s, ok := v.(Raw); ok { - values = append(values, gconv.String(s)) - } else { - values = append(values, "?") - params = append(params, v) - } - } - if option == insertOptionSave { - for k, _ := range dataMap { - // If it's SAVE operation, - // do not automatically update the creating time. 
- if c.isSoftCreatedFiledName(k) { - continue - } - if len(updateStr) > 0 { - updateStr += "," - } - updateStr += fmt.Sprintf( - "%s%s%s=VALUES(%s%s%s)", - charL, k, charR, - charL, k, charR, - ) - } - updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", updateStr) - } - if link == nil { - if link, err = c.MasterLink(); err != nil { - return nil, err - } - } - return c.db.DoExec(ctx, link, fmt.Sprintf( - "%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(fields, ","), - strings.Join(values, ","), updateStr, - ), params...) -} - -// BatchInsert batch inserts data. -// The parameter `list` must be type of slice of map or struct. -func (c *Core) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return c.Model(table).Data(list).Batch(batch[0]).Insert() - } - return c.Model(table).Data(list).Insert() -} - -// BatchInsertIgnore batch inserts data with ignore option. -// The parameter `list` must be type of slice of map or struct. -func (c *Core) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return c.Model(table).Data(list).Batch(batch[0]).InsertIgnore() - } - return c.Model(table).Data(list).InsertIgnore() -} - -// BatchReplace batch replaces data. -// The parameter `list` must be type of slice of map or struct. -func (c *Core) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return c.Model(table).Data(list).Batch(batch[0]).Replace() - } - return c.Model(table).Data(list).Replace() -} - -// BatchSave batch replaces data. -// The parameter `list` must be type of slice of map or struct. -func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return c.Model(table).Data(list).Batch(batch[0]).Save() - } - return c.Model(table).Data(list).Save() -} - -// DoBatchInsert batch inserts/replaces/saves data. -// This function is usually used for custom interface definition, you do not need call it manually. -func (c *Core) DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { - table = c.QuotePrefixTableName(table) - var ( - keys []string // Field names. - values []string // Value holder string array, like: (?,?,?) - params []interface{} // Values that will be committed to underlying database driver. - listMap List // The data list that passed from caller. - ) - switch value := list.(type) { - case Result: - listMap = value.List() - case Record: - listMap = List{value.Map()} - case List: - listMap = value - case Map: - listMap = List{value} - default: - var ( - rv = reflect.ValueOf(list) - kind = rv.Kind() - ) - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { - // If it's slice type, it then converts it to List type. 
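The `BatchInsert`/`BatchInsertIgnore`/`BatchReplace`/`BatchSave` wrappers deleted above were already one-liners over the model chain, which stays the supported way to batch-write. A sketch of the equivalent calls, with a placeholder table and data:

```go
package main

import (
	"github.com/gogf/gf/database/gdb"
	"github.com/gogf/gf/frame/g"
)

func main() {
	db := g.DB()
	list := gdb.List{
		{"name": "john", "score": 100},
		{"name": "smith", "score": 60},
	}
	// Equivalent of the removed db.BatchInsert("user", list, 10):
	if _, err := db.Model("user").Data(list).Batch(10).Insert(); err != nil {
		panic(err)
	}
	// The removed BatchSave maps to the same chain ending in Save():
	if _, err := db.Model("user").Data(list).Batch(10).Save(); err != nil {
		panic(err)
	}
}
```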
- case reflect.Slice, reflect.Array: - listMap = make(List, rv.Len()) - for i := 0; i < rv.Len(); i++ { - listMap[i] = ConvertDataForTableRecord(rv.Index(i).Interface()) - } - case reflect.Map: - listMap = List{ConvertDataForTableRecord(value)} - case reflect.Struct: - if v, ok := value.(apiInterfaces); ok { - var ( - array = v.Interfaces() - list = make(List, len(array)) - ) - for i := 0; i < len(array); i++ { - list[i] = ConvertDataForTableRecord(array[i]) - } - listMap = list - } else { - listMap = List{ConvertDataForTableRecord(value)} - } - default: - return result, gerror.New(fmt.Sprint("unsupported list type:", kind)) - } - } - if len(listMap) < 1 { - return result, gerror.New("data list cannot be empty") - } - if link == nil { - if link, err = c.MasterLink(); err != nil { - return - } - } // Handle the field names and place holders. - for k, _ := range listMap[0] { + for k, _ := range list[0] { keys = append(keys, k) } // Prepare the batch result pointer. @@ -513,54 +381,35 @@ func (c *Core) DoBatchInsert(ctx context.Context, link Link, table string, list charL, charR = c.db.GetChars() batchResult = new(SqlResult) keysStr = charL + strings.Join(keys, charR+","+charL) + charR - operation = GetInsertOperationByOption(option) - updateStr = "" + operation = GetInsertOperationByOption(option.InsertOption) ) - if option == insertOptionSave { - for _, k := range keys { - // If it's SAVE operation, - // do not automatically update the creating time. - if c.isSoftCreatedFiledName(k) { - continue - } - if len(updateStr) > 0 { - updateStr += "," - } - updateStr += fmt.Sprintf( - "%s%s%s=VALUES(%s%s%s)", - charL, k, charR, - charL, k, charR, - ) - } - updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", updateStr) - } - batchNum := defaultBatchNumber - if len(batch) > 0 && batch[0] > 0 { - batchNum = batch[0] + if option.InsertOption == insertOptionSave { + onDuplicateStr = c.formatOnDuplicate(keys, option) } var ( - listMapLen = len(listMap) + listLength = len(list) valueHolder = make([]string, 0) ) - for i := 0; i < listMapLen; i++ { + for i := 0; i < listLength; i++ { values = values[:0] // Note that the map type is unordered, // so it should use slice+key to retrieve the value. for _, k := range keys { - if s, ok := listMap[i][k].(Raw); ok { + if s, ok := list[i][k].(Raw); ok { values = append(values, gconv.String(s)) } else { values = append(values, "?") - params = append(params, listMap[i][k]) + params = append(params, list[i][k]) } } valueHolder = append(valueHolder, "("+gstr.Join(values, ",")+")") - if len(valueHolder) == batchNum || (i == listMapLen-1 && len(valueHolder) > 0) { + // Batch package checks: It meets the batch number or it is the last element. + if len(valueHolder) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) { r, err := c.db.DoExec(ctx, link, fmt.Sprintf( "%s INTO %s(%s) VALUES%s %s", - operation, table, keysStr, + operation, c.QuotePrefixTableName(table), keysStr, gstr.Join(valueHolder, ","), - updateStr, + onDuplicateStr, ), params...) 
if err != nil { return r, err @@ -578,6 +427,52 @@ func (c *Core) DoBatchInsert(ctx context.Context, link Link, table string, list return batchResult, nil } +func (c *Core) formatOnDuplicate(columns []string, option DoInsertOption) string { + var ( + onDuplicateStr string + ) + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case Raw, *Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + c.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(k), + c.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, + // do not automatically update the creating time. + if c.isSoftCreatedFilledName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(column), + c.QuoteWord(column), + ) + } + } + return fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", onDuplicateStr) +} + // Update does "UPDATE ... " statement for the table. // // The parameter `data` can be type of string/map/gmap/struct/*struct, etc. @@ -711,7 +606,7 @@ func (c *Core) convertRowsToResult(rows *sql.Rows) (Result, error) { } var ( values = make([]interface{}, len(columnNames)) - records = make(Result, 0) + result = make(Result, 0) scanArgs = make([]interface{}, len(values)) ) for i := range values { @@ -719,22 +614,22 @@ func (c *Core) convertRowsToResult(rows *sql.Rows) (Result, error) { } for { if err := rows.Scan(scanArgs...); err != nil { - return records, err + return result, err } - row := make(Record) + record := Record{} for i, value := range values { if value == nil { - row[columnNames[i]] = gvar.New(nil) + record[columnNames[i]] = gvar.New(nil) } else { - row[columnNames[i]] = gvar.New(c.convertFieldValueToLocalValue(value, columnTypes[i])) + record[columnNames[i]] = gvar.New(c.convertFieldValueToLocalValue(value, columnTypes[i])) } } - records = append(records, row) + result = append(result, record) if !rows.Next() { break } } - return records, nil + return result, nil } // MarshalJSON implements the interface MarshalJSON for json.Marshal. @@ -778,8 +673,8 @@ func (c *Core) HasTable(name string) (bool, error) { return false, nil } -// isSoftCreatedFiledName checks and returns whether given filed name is an automatic-filled created time. -func (c *Core) isSoftCreatedFiledName(fieldName string) bool { +// isSoftCreatedFilledName checks and returns whether given filed name is an automatic-filled created time. +func (c *Core) isSoftCreatedFilledName(fieldName string) bool { if fieldName == "" { return false } diff --git a/database/gdb/gdb_core_config.go b/database/gdb/gdb_core_config.go index 8ce85dc8e..0982d6130 100644 --- a/database/gdb/gdb_core_config.go +++ b/database/gdb/gdb_core_config.go @@ -30,13 +30,14 @@ type ConfigNode struct { Pass string `json:"pass"` // Authentication password. Name string `json:"name"` // Default used database name. Type string `json:"type"` // Database type: mysql, sqlite, mssql, pgsql, oracle. + Link string `json:"link"` // (Optional) Custom link information, when it is used, configuration Host/Port/User/Pass/Name are ignored. Role string `json:"role"` // (Optional, "master" in default) Node role, used for master-slave mode: master, slave. 
Debug bool `json:"debug"` // (Optional) Debug mode enables debug information logging and output. Prefix string `json:"prefix"` // (Optional) Table prefix. DryRun bool `json:"dryRun"` // (Optional) Dry run, which does SELECT but no INSERT/UPDATE/DELETE statements. Weight int `json:"weight"` // (Optional) Weight for load balance calculating, it's useless if there's just one node. Charset string `json:"charset"` // (Optional, "utf8mb4" in default) Custom charset when operating on database. - LinkInfo string `json:"link"` // (Optional) Custom link information, when it is used, configuration Host/Port/User/Pass/Name are ignored. + Timezone string `json:"timezone"` // (Optional) Sets the time zone for displaying and interpreting time stamps. MaxIdleConnCount int `json:"maxIdle"` // (Optional) Max idle connection configuration for underlying connection pool. MaxOpenConnCount int `json:"maxOpen"` // (Optional) Max open connection configuration for underlying connection pool. MaxConnLifeTime time.Duration `json:"maxLifeTime"` // (Optional) Max amount of time a connection may be idle before being closed. @@ -48,6 +49,7 @@ type ConfigNode struct { UpdatedAt string `json:"updatedAt"` // (Optional) The filed name of table for automatic-filled updated datetime. DeletedAt string `json:"deletedAt"` // (Optional) The filed name of table for automatic-filled updated datetime. TimeMaintainDisabled bool `json:"timeMaintainDisabled"` // (Optional) Disable the automatic time maintaining feature. + CtxStrict bool `json:"ctxStrict"` // (Optional) Strictly require context input for all database operations. } const ( @@ -186,7 +188,7 @@ func (node *ConfigNode) String() string { node.MaxIdleConnCount, node.MaxOpenConnCount, node.MaxConnLifeTime, - node.LinkInfo, + node.Link, ) } diff --git a/database/gdb/gdb_core_tracing.go b/database/gdb/gdb_core_tracing.go index 3c746ec2c..0b0bc6d59 100644 --- a/database/gdb/gdb_core_tracing.go +++ b/database/gdb/gdb_core_tracing.go @@ -12,7 +12,6 @@ import ( "fmt" "github.com/gogf/gf" "github.com/gogf/gf/net/gtrace" - "github.com/gogf/gf/os/gcmd" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -34,19 +33,9 @@ const ( tracingEventDbExecutionType = "db.execution.type" ) -var ( - // tracingInternal enables tracing for internal type spans. - // It's true in default. - tracingInternal = true -) - -func init() { - tracingInternal = gcmd.GetOptWithEnv("gf.tracing.internal", true).Bool() -} - // addSqlToTracing adds sql information to tracer if it's enabled. 
func (c *Core) addSqlToTracing(ctx context.Context, sql *Sql) { - if !tracingInternal || !gtrace.IsActivated(ctx) { + if !gtrace.IsTracingInternal() || !gtrace.IsActivated(ctx) { return } tr := otel.GetTracerProvider().Tracer( @@ -76,8 +65,8 @@ func (c *Core) addSqlToTracing(ctx context.Context, sql *Sql) { if c.db.GetConfig().User != "" { labels = append(labels, attribute.String(tracingAttrDbUser, c.db.GetConfig().User)) } - if filteredLinkInfo := c.db.FilteredLinkInfo(); filteredLinkInfo != "" { - labels = append(labels, attribute.String(tracingAttrDbLink, c.db.FilteredLinkInfo())) + if filteredLink := c.db.FilteredLink(); filteredLink != "" { + labels = append(labels, attribute.String(tracingAttrDbLink, c.db.FilteredLink())) } if group := c.db.GetGroup(); group != "" { labels = append(labels, attribute.String(tracingAttrDbGroup, group)) diff --git a/database/gdb/gdb_core_transaction.go b/database/gdb/gdb_core_transaction.go index d9e575ce8..0252f1096 100644 --- a/database/gdb/gdb_core_transaction.go +++ b/database/gdb/gdb_core_transaction.go @@ -28,6 +28,7 @@ type TX struct { master *sql.DB // master is the raw and underlying database manager. transactionId string // transactionId is an unique id generated by this object for this transaction. transactionCount int // transactionCount marks the times that Begins. + isClosed bool // isClosed marks this transaction has already been committed or rolled back. } const ( @@ -162,6 +163,9 @@ func TXFromCtx(ctx context.Context, group string) *TX { v := ctx.Value(transactionKeyForContext(group)) if v != nil { tx := v.(*TX) + if tx.IsClosed() { + return nil + } tx.ctx = ctx return tx } @@ -210,6 +214,7 @@ func (tx *TX) Commit() error { IsTransaction: true, } ) + tx.isClosed = true tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj) if tx.db.GetDebug() { tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj) @@ -243,6 +248,7 @@ func (tx *TX) Rollback() error { IsTransaction: true, } ) + tx.isClosed = true tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj) if tx.db.GetDebug() { tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj) @@ -250,6 +256,11 @@ func (tx *TX) Rollback() error { return err } +// IsClosed checks and returns this transaction has already been committed or rolled back. +func (tx *TX) IsClosed() bool { + return tx.isClosed +} + // Begin starts a nested transaction procedure. func (tx *TX) Begin() error { _, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint()) @@ -503,42 +514,6 @@ func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, er return tx.Model(table).Ctx(tx.ctx).Data(data).Save() } -// BatchInsert batch inserts data. -// The parameter `list` must be type of slice of map or struct. -func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Insert() - } - return tx.Model(table).Ctx(tx.ctx).Data(list).Insert() -} - -// BatchInsertIgnore batch inserts data with ignore option. -// The parameter `list` must be type of slice of map or struct. -func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).InsertIgnore() - } - return tx.Model(table).Ctx(tx.ctx).Data(list).InsertIgnore() -} - -// BatchReplace batch replaces data. -// The parameter `list` must be type of slice of map or struct. 
-func (tx *TX) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Replace() - } - return tx.Model(table).Ctx(tx.ctx).Data(list).Replace() -} - -// BatchSave batch replaces data. -// The parameter `list` must be type of slice of map or struct. -func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) { - if len(batch) > 0 { - return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Save() - } - return tx.Model(table).Ctx(tx.ctx).Data(list).Save() -} - // Update does "UPDATE ... " statement for the table. // // The parameter `data` can be type of string/map/gmap/struct/*struct, etc. diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index baa0e5027..ba09b503f 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -10,6 +10,7 @@ package gdb import ( "context" "database/sql" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" ) @@ -33,12 +34,17 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter link = &txLink{tx.tx} } } - // Link execution. - sql, args = formatSql(sql, args) - sql, args = c.db.HandleSqlBeforeCommit(ctx, link, sql, args) + if c.GetConfig().QueryTimeout > 0 { ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout) } + + // Link execution. + sql, args = formatSql(sql, args) + sql, args, err = c.db.DoCommit(ctx, link, sql, args) + if err != nil { + return nil, err + } mTime1 := gtime.TimestampMilli() rows, err = link.QueryContext(ctx, sql, args...) mTime2 := gtime.TimestampMilli() @@ -85,15 +91,19 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf link = &txLink{tx.tx} } } - // Link execution. - sql, args = formatSql(sql, args) - sql, args = c.db.HandleSqlBeforeCommit(ctx, link, sql, args) + if c.GetConfig().ExecTimeout > 0 { var cancelFunc context.CancelFunc ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().ExecTimeout) defer cancelFunc() } + // Link execution. + sql, args = formatSql(sql, args) + sql, args, err = c.db.DoCommit(ctx, link, sql, args) + if err != nil { + return nil, err + } mTime1 := gtime.TimestampMilli() if !c.db.GetDryRun() { result, err = link.ExecContext(ctx, sql, args...) @@ -120,6 +130,18 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf return result, formatError(err, sql, args...) } +// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver. +// The parameter `link` specifies the current database connection operation object. You can modify the sql +// string `sql` and its arguments `args` as you wish before they're committed to driver. +func (c *Core) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + if c.db.GetConfig().CtxStrict { + if v := ctx.Value(ctxStrictKeyName); v == nil { + return sql, args, gerror.New(ctxStrictErrorStr) + } + } + return sql, args, nil +} + // Prepare creates a prepared statement for later queries or executions. // Multiple queries or executions may be run concurrently from the // returned statement. @@ -156,6 +178,13 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, err // DO NOT USE cancel function in prepare statement. 
ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout) } + + if c.db.GetConfig().CtxStrict { + if v := ctx.Value(ctxStrictKeyName); v == nil { + return nil, gerror.New(ctxStrictErrorStr) + } + } + var ( mTime1 = gtime.TimestampMilli() stmt, err = link.PrepareContext(ctx, sql) diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index 0eaeb16c4..87add6748 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -63,12 +63,6 @@ func (c *Core) GetChars() (charLeft string, charRight string) { return "", "" } -// HandleSqlBeforeCommit handles the sql before posts it to database. -// It does nothing in default. -func (c *Core) HandleSqlBeforeCommit(sql string) string { - return sql -} - // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. // diff --git a/database/gdb/gdb_driver_mssql.go b/database/gdb/gdb_driver_mssql.go index 33d3f8732..24156605c 100644 --- a/database/gdb/gdb_driver_mssql.go +++ b/database/gdb/gdb_driver_mssql.go @@ -42,15 +42,15 @@ func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) { // Open creates and returns a underlying sql.DB object for mssql. func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) { source := "" - if config.LinkInfo != "" { - source = config.LinkInfo + if config.Link != "" { + source = config.Link } else { source = fmt.Sprintf( "user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", config.User, config.Pass, config.Host, config.Port, config.Name, ) } - intlog.Printf("Open: %s", source) + intlog.Printf(d.GetCtx(), "Open: %s", source) if db, err := sql.Open("sqlserver", source); err == nil { return db, nil } else { @@ -58,17 +58,17 @@ func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) { } } -// FilteredLinkInfo retrieves and returns filtered `linkInfo` that can be using for +// FilteredLink retrieves and returns filtered `linkInfo` that can be using for // logging or tracing purpose. -func (d *DriverMssql) FilteredLinkInfo() string { - linkInfo := d.GetConfig().LinkInfo +func (d *DriverMssql) FilteredLink() string { + linkInfo := d.GetConfig().Link if linkInfo == "" { return "" } s, _ := gregex.ReplaceString( `(.+);\s*password=(.+);\s*server=(.+)`, `$1;password=xxx;server=$3`, - d.GetConfig().LinkInfo, + d.GetConfig().Link, ) return s } @@ -78,8 +78,11 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverMssql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { +// DoCommit deals with the sql string before commits it to underlying sql driver. +func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + defer func() { + newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs) + }() var index int // Convert place holder char '?' to string "@px". 
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string { @@ -87,7 +90,7 @@ func (d *DriverMssql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql return fmt.Sprintf("@p%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return d.parseSql(str), args + return d.parseSql(str), args, nil } // parseSql does some replacement of the sql before commits it to underlying driver, diff --git a/database/gdb/gdb_driver_mysql.go b/database/gdb/gdb_driver_mysql.go index aab15faec..a9ac18296 100644 --- a/database/gdb/gdb_driver_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "fmt" + "net/url" "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" @@ -36,8 +37,8 @@ func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) { // Note that it converts time.Time argument to local timezone in default. func (d *DriverMysql) Open(config *ConfigNode) (*sql.DB, error) { var source string - if config.LinkInfo != "" { - source = config.LinkInfo + if config.Link != "" { + source = config.Link // Custom changing the schema in runtime. if config.Name != "" { source, _ = gregex.ReplaceString(`/([\w\.\-]+)+`, "/"+config.Name, source) @@ -47,8 +48,11 @@ func (d *DriverMysql) Open(config *ConfigNode) (*sql.DB, error) { "%s:%s@tcp(%s:%s)/%s?charset=%s", config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset, ) + if config.Timezone != "" { + source = fmt.Sprintf("%s&loc=%s", source, url.QueryEscape(config.Timezone)) + } } - intlog.Printf("Open: %s", source) + intlog.Printf(d.GetCtx(), "Open: %s", source) if db, err := sql.Open("mysql", source); err == nil { return db, nil } else { @@ -56,10 +60,10 @@ func (d *DriverMysql) Open(config *ConfigNode) (*sql.DB, error) { } } -// FilteredLinkInfo retrieves and returns filtered `linkInfo` that can be using for +// FilteredLink retrieves and returns filtered `linkInfo` that can be using for // logging or tracing purpose. -func (d *DriverMysql) FilteredLinkInfo() string { - linkInfo := d.GetConfig().LinkInfo +func (d *DriverMysql) FilteredLink() string { + linkInfo := d.GetConfig().Link if linkInfo == "" { return "" } @@ -76,9 +80,9 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) { return "`", "`" } -// HandleSqlBeforeCommit handles the sql before posts it to database. -func (d *DriverMysql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { - return sql, args +// DoCommit handles the sql before posts it to database. +func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + return d.Core.DoCommit(ctx, link, sql, args) } // Tables retrieves and returns the tables of current schema. diff --git a/database/gdb/gdb_driver_oracle.go b/database/gdb/gdb_driver_oracle.go index 9d8543936..8c175f228 100644 --- a/database/gdb/gdb_driver_oracle.go +++ b/database/gdb/gdb_driver_oracle.go @@ -32,11 +32,6 @@ type DriverOracle struct { *Core } -const ( - tableAlias1 = "GFORM1" - tableAlias2 = "GFORM2" -) - // New creates and returns a database object for oracle. // It implements the interface of gdb.Driver for extra database driver installation. func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) { @@ -48,15 +43,15 @@ func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) { // Open creates and returns a underlying sql.DB object for oracle. 
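`HandleSqlBeforeCommit` is replaced by the `DoCommit` hook across the drivers above, and the hook now also returns an error. A rough sketch of what an external driver's override could look like, following the mssql/oracle/pgsql pattern of deferring to `Core.DoCommit`; `MyDriver` is hypothetical and a complete driver would also implement `New` and `Open`:

```go
package mydriver

import (
	"context"

	"github.com/gogf/gf/database/gdb"
)

// MyDriver is a hypothetical custom driver embedding gdb.Core, like the
// built-in drivers in this patch.
type MyDriver struct {
	*gdb.Core
}

// DoCommit may rewrite the sql string and its arguments right before they
// are handed to the underlying driver. Delegating to Core.DoCommit keeps
// core-level handling such as the new CtxStrict check in place.
func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
	newSql, newArgs = sql, args // placeholder conversion etc. would go here
	return d.Core.DoCommit(ctx, link, newSql, newArgs)
}
```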
func (d *DriverOracle) Open(config *ConfigNode) (*sql.DB, error) { var source string - if config.LinkInfo != "" { - source = config.LinkInfo + if config.Link != "" { + source = config.Link } else { source = fmt.Sprintf( "%s/%s@%s:%s/%s", config.User, config.Pass, config.Host, config.Port, config.Name, ) } - intlog.Printf("Open: %s", source) + intlog.Printf(d.GetCtx(), "Open: %s", source) if db, err := sql.Open("oci8", source); err == nil { return db, nil } else { @@ -64,10 +59,10 @@ func (d *DriverOracle) Open(config *ConfigNode) (*sql.DB, error) { } } -// FilteredLinkInfo retrieves and returns filtered `linkInfo` that can be using for +// FilteredLink retrieves and returns filtered `linkInfo` that can be using for // logging or tracing purpose. -func (d *DriverOracle) FilteredLinkInfo() string { - linkInfo := d.GetConfig().LinkInfo +func (d *DriverOracle) FilteredLink() string { + linkInfo := d.GetConfig().Link if linkInfo == "" { return "" } @@ -84,8 +79,12 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverOracle) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) { +// DoCommit deals with the sql string before commits it to underlying sql driver. +func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + defer func() { + newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs) + }() + var index int // Convert place holder char '?' to string ":vx". newSql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string { @@ -264,203 +263,40 @@ func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[ return } -func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { - var ( - fields []string - values []string - params []interface{} - dataMap Map - rv = reflect.ValueOf(data) - kind = rv.Kind() - ) - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { - case reflect.Slice, reflect.Array: - return d.DoBatchInsert(ctx, link, table, data, option, batch...) 
- case reflect.Map: - fallthrough - case reflect.Struct: - dataMap = ConvertDataForTableRecord(data) - default: - return result, gerror.New(fmt.Sprint("unsupported data type:", kind)) - } - var ( - indexes = make([]string, 0) - indexMap = make(map[string]string) - indexExists = false - ) - if option != insertOptionDefault { - index, err := d.getTableUniqueIndex(table) - if err != nil { - return nil, err - } - - if len(index) > 0 { - for _, v := range index { - for k, _ := range v { - indexes = append(indexes, k) - } - indexMap = v - indexExists = true - break - } - } - } - var ( - subSqlStr = make([]string, 0) - onStr = make([]string, 0) - updateStr = make([]string, 0) - ) - charL, charR := d.db.GetChars() - for k, v := range dataMap { - k = strings.ToUpper(k) - - // 操作类型为REPLACE/SAVE时且存在唯一索引才使用merge,否则使用insert - if (option == insertOptionReplace || option == insertOptionSave) && indexExists { - fields = append(fields, tableAlias1+"."+charL+k+charR) - values = append(values, tableAlias2+"."+charL+k+charR) - params = append(params, v) - subSqlStr = append(subSqlStr, fmt.Sprintf("%s?%s %s", charL, charR, k)) - //m erge中的on子句中由唯一索引组成, update子句中不含唯一索引 - if _, ok := indexMap[k]; ok { - onStr = append(onStr, fmt.Sprintf("%s.%s = %s.%s ", tableAlias1, k, tableAlias2, k)) - } else { - updateStr = append(updateStr, fmt.Sprintf("%s.%s = %s.%s ", tableAlias1, k, tableAlias2, k)) - } - } else { - fields = append(fields, charL+k+charR) - values = append(values, "?") - params = append(params, v) - } - } - - if link == nil { - if link, err = d.MasterLink(); err != nil { - return nil, err - } - } - - if indexExists && option != insertOptionDefault { - switch option { - case - insertOptionReplace, - insertOptionSave: - tmp := fmt.Sprintf( - "MERGE INTO %s %s USING(SELECT %s FROM DUAL) %s ON(%s) WHEN MATCHED THEN UPDATE SET %s WHEN NOT MATCHED THEN INSERT (%s) VALUES(%s)", - table, tableAlias1, strings.Join(subSqlStr, ","), tableAlias2, - strings.Join(onStr, "AND"), strings.Join(updateStr, ","), strings.Join(fields, ","), strings.Join(values, ","), - ) - return d.DoExec(ctx, link, tmp, params...) - - case insertOptionIgnore: - return d.DoExec(ctx, link, fmt.Sprintf( - "INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)", - table, strings.Join(indexes, ","), table, strings.Join(fields, ","), strings.Join(values, ","), - ), params...) - } - } - - return d.DoExec(ctx, link, - fmt.Sprintf( - "INSERT INTO %s(%s) VALUES(%s)", - table, strings.Join(fields, ","), strings.Join(values, ","), - ), - params...) 
-} - -func (d *DriverOracle) DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { +func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) { var ( keys []string values []string params []interface{} ) - listMap := (List)(nil) - switch v := list.(type) { - case Result: - listMap = v.List() - case Record: - listMap = List{v.Map()} - case List: - listMap = v - case Map: - listMap = List{v} - default: - var ( - rv = reflect.ValueOf(list) - kind = rv.Kind() - ) - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { - case reflect.Slice, reflect.Array: - listMap = make(List, rv.Len()) - for i := 0; i < rv.Len(); i++ { - listMap[i] = ConvertDataForTableRecord(rv.Index(i).Interface()) - } - case reflect.Map: - fallthrough - case reflect.Struct: - listMap = List{ConvertDataForTableRecord(list)} - default: - return result, gerror.New(fmt.Sprint("unsupported list type:", kind)) - } - } - if len(listMap) < 1 { - return result, gerror.New("empty data list") - } - if link == nil { - if link, err = d.MasterLink(); err != nil { - return - } - } // Retrieve the table fields and length. - holders := []string(nil) - for k, _ := range listMap[0] { + var ( + listLength = len(list) + valueHolder = make([]string, 0) + ) + for k, _ := range list[0] { keys = append(keys, k) - holders = append(holders, "?") + valueHolder = append(valueHolder, "?") } var ( batchResult = new(SqlResult) charL, charR = d.db.GetChars() keyStr = charL + strings.Join(keys, charL+","+charR) + charR - valueHolderStr = strings.Join(holders, ",") + valueHolderStr = strings.Join(valueHolder, ",") ) - if option != insertOptionDefault { - for _, v := range listMap { - r, err := d.DoInsert(ctx, link, table, v, option, 1) - if err != nil { - return r, err - } - - if n, err := r.RowsAffected(); err != nil { - return r, err - } else { - batchResult.result = r - batchResult.affected += n - } - } - return batchResult, nil - } - - batchNum := defaultBatchNumber - if len(batch) > 0 { - batchNum = batch[0] - } // Format "INSERT...INTO..." statement. intoStr := make([]string, 0) - for i := 0; i < len(listMap); i++ { + for i := 0; i < len(list); i++ { for _, k := range keys { - params = append(params, listMap[i][k]) + params = append(params, list[i][k]) } values = append(values, valueHolderStr) - intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr)) - if len(intoStr) == batchNum { - r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) + intoStr = append(intoStr, fmt.Sprintf("INTO %s(%s) VALUES(%s)", table, keyStr, valueHolderStr)) + if len(intoStr) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) { + r, err := d.DoExec(ctx, link, fmt.Sprintf( + "INSERT ALL %s SELECT * FROM DUAL", + strings.Join(intoStr, " "), + ), params...) if err != nil { return r, err } @@ -474,18 +310,5 @@ func (d *DriverOracle) DoBatchInsert(ctx context.Context, link Link, table strin intoStr = intoStr[:0] } } - // The leftover data. - if len(intoStr) > 0 { - r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) 
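// With the unified DoInsert, rows are flushed once option.BatchCount entries (or the
// tail of the list) have been collected, producing a single Oracle multi-row statement
// of roughly this shape:
//  INSERT ALL
//      INTO user(ID,NAME) VALUES(?,?)
//      INTO user(ID,NAME) VALUES(?,?)
//  SELECT * FROM DUAL
// A usage sketch, assuming a configured Oracle group and the Batch chaining method:
//  _, err := db.Model("user").Data(list).Batch(500).Insert()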
- if err != nil { - return r, err - } - if n, err := r.RowsAffected(); err != nil { - return r, err - } else { - batchResult.result = r - batchResult.affected += n - } - } return batchResult, nil } diff --git a/database/gdb/gdb_driver_pgsql.go b/database/gdb/gdb_driver_pgsql.go index e167a64eb..a1ada0252 100644 --- a/database/gdb/gdb_driver_pgsql.go +++ b/database/gdb/gdb_driver_pgsql.go @@ -40,15 +40,18 @@ func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) { // Open creates and returns a underlying sql.DB object for pgsql. func (d *DriverPgsql) Open(config *ConfigNode) (*sql.DB, error) { var source string - if config.LinkInfo != "" { - source = config.LinkInfo + if config.Link != "" { + source = config.Link } else { source = fmt.Sprintf( "user=%s password=%s host=%s port=%s dbname=%s sslmode=disable", config.User, config.Pass, config.Host, config.Port, config.Name, ) + if config.Timezone != "" { + source = fmt.Sprintf("%s timezone=%s", source, config.Timezone) + } } - intlog.Printf("Open: %s", source) + intlog.Printf(d.GetCtx(), "Open: %s", source) if db, err := sql.Open("postgres", source); err == nil { return db, nil } else { @@ -56,10 +59,10 @@ func (d *DriverPgsql) Open(config *ConfigNode) (*sql.DB, error) { } } -// FilteredLinkInfo retrieves and returns filtered `linkInfo` that can be using for +// FilteredLink retrieves and returns filtered `linkInfo` that can be using for // logging or tracing purpose. -func (d *DriverPgsql) FilteredLinkInfo() string { - linkInfo := d.GetConfig().LinkInfo +func (d *DriverPgsql) FilteredLink() string { + linkInfo := d.GetConfig().Link if linkInfo == "" { return "" } @@ -76,16 +79,20 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverPgsql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { +// DoCommit deals with the sql string before commits it to underlying sql driver. +func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + defer func() { + newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs) + }() + var index int // Convert place holder char '?' to string "$x". sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string { index++ return fmt.Sprintf("$%d", index) }) - sql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $2 OFFSET $1`, sql) - return sql, args + newSql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $2 OFFSET $1`, sql) + return newSql, args, nil } // Tables retrieves and returns the tables of current schema. diff --git a/database/gdb/gdb_driver_sqlite.go b/database/gdb/gdb_driver_sqlite.go index 99ae55fcf..1ad0f68b3 100644 --- a/database/gdb/gdb_driver_sqlite.go +++ b/database/gdb/gdb_driver_sqlite.go @@ -38,8 +38,8 @@ func (d *DriverSqlite) New(core *Core, node *ConfigNode) (DB, error) { // Open creates and returns a underlying sql.DB object for sqlite. 
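// For pgsql, DoCommit rewrites '?' place holders to positional "$n" parameters and
// converts the MySQL-style limit clause, roughly:
//  SELECT * FROM "user" WHERE "id">? LIMIT 10,20
//  -> SELECT * FROM "user" WHERE "id">$1 LIMIT 20 OFFSET 10
// When ConfigNode.Timezone is set, Open additionally appends "timezone=<value>" to the
// connection string.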
func (d *DriverSqlite) Open(config *ConfigNode) (*sql.DB, error) { var source string - if config.LinkInfo != "" { - source = config.LinkInfo + if config.Link != "" { + source = config.Link } else { source = config.Name } @@ -47,7 +47,7 @@ func (d *DriverSqlite) Open(config *ConfigNode) (*sql.DB, error) { if absolutePath, _ := gfile.Search(source); absolutePath != "" { source = absolutePath } - intlog.Printf("Open: %s", source) + intlog.Printf(d.GetCtx(), "Open: %s", source) if db, err := sql.Open("sqlite3", source); err == nil { return db, nil } else { @@ -55,10 +55,10 @@ func (d *DriverSqlite) Open(config *ConfigNode) (*sql.DB, error) { } } -// FilteredLinkInfo retrieves and returns filtered `linkInfo` that can be using for +// FilteredLink retrieves and returns filtered `linkInfo` that can be using for // logging or tracing purpose. -func (d *DriverSqlite) FilteredLinkInfo() string { - return d.GetConfig().LinkInfo +func (d *DriverSqlite) FilteredLink() string { + return d.GetConfig().Link } // GetChars returns the security char for this type of database. @@ -66,11 +66,9 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) { return "`", "`" } -// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -// TODO 需要增加对Save方法的支持,可使用正则来实现替换, -// TODO 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE) -func (d *DriverSqlite) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { - return sql, args +// DoCommit deals with the sql string before commits it to underlying sql driver. +func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + return d.Core.DoCommit(ctx, link, sql, args) } // Tables retrieves and returns the tables of current schema. diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index fc2c4d1ca..e2a0b4c17 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -142,8 +142,8 @@ func GetInsertOperationByOption(option int) string { // ConvertDataForTableRecord is a very important function, which does converting for any data that // will be inserted into table as a record. // -// The parameter `obj` should be type of *map/map/*struct/struct. -// It supports inherit struct definition for struct. +// The parameter `value` should be type of *map/map/*struct/struct. +// It supports embedded struct definition for struct. func ConvertDataForTableRecord(value interface{}) map[string]interface{} { var ( rvValue reflect.Value @@ -164,12 +164,15 @@ func ConvertDataForTableRecord(value interface{}) map[string]interface{} { // Convert the value to JSON. data[k], _ = json.Marshal(v) } + case reflect.Struct: switch v.(type) { case time.Time, *time.Time, gtime.Time, *gtime.Time: continue + case Counter, *Counter: continue + default: // Use string conversion in default. if s, ok := v.(apiString); ok { @@ -186,7 +189,7 @@ func ConvertDataForTableRecord(value interface{}) map[string]interface{} { // DataToMapDeep converts `value` to map type recursively. // The parameter `value` should be type of *map/map/*struct/struct. -// It supports inherit struct definition for struct. +// It supports embedded struct definition for struct. 
func DataToMapDeep(value interface{}) map[string]interface{} { if v, ok := value.(apiMapStrAny); ok { return v.MapStrAny() @@ -445,7 +448,7 @@ func formatSql(sql string, args []interface{}) (newSql string, newArgs []interfa } // formatWhere formats where statement and its arguments for `Where` and `Having` statements. -func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (newWhere string, newArgs []interface{}) { +func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, schema, table string) (newWhere string, newArgs []interface{}) { var ( buffer = bytes.NewBuffer(nil) rv = reflect.ValueOf(where) @@ -483,7 +486,12 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) ( }) break } - for key, value := range DataToMapDeep(where) { + // Automatically mapping and filtering the struct attribute. + data := DataToMapDeep(where) + if table != "" { + data, _ = db.GetCore().mappingAndFilterData(schema, table, data, true) + } + for key, value := range data { if omitEmpty && empty.IsEmpty(value) { continue } diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index 627051adb..245398260 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -17,10 +17,11 @@ import ( "github.com/gogf/gf/text/gstr" ) -// Model is the DAO for ORM. +// Model is core struct implementing the DAO for ORM. type Model struct { db DB // Underlying DB interface. tx *TX // Underlying TX interface. + rawSql string // rawSql is the raw SQL string which marks a raw SQL based Model not a table based Model. schema string // Custom database schema. linkType int // Mark for operation on master or slave. tablesInit string // Table names when model initialization. @@ -48,6 +49,8 @@ type Model struct { cacheName string // Cache name for custom operation. unscoped bool // Disables soft deleting features when select/delete operations. safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. + onDuplicate interface{} // onDuplicate is used for ON "DUPLICATE KEY UPDATE" statement. + onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns ON "DUPLICATE KEY UPDATE" statement. } // whereHolder is the holder for where condition preparing. @@ -77,29 +80,32 @@ func (c *Core) Table(tableNameQueryOrStruct ...interface{}) *Model { // Model creates and returns a new ORM model from given schema. // The parameter `tableNameQueryOrStruct` can be more than one table names, and also alias name, like: // 1. Model names: -// Model("user") -// Model("user u") -// Model("user, user_detail") -// Model("user u, user_detail ud") -// 2. Model name with alias: Model("user", "u") +// db.Model("user") +// db.Model("user u") +// db.Model("user, user_detail") +// db.Model("user u, user_detail ud") +// 2. Model name with alias: +// db.Model("user", "u") +// 3. Model name with sub-query: +// db.Model("? AS a, ? AS b", subQuery1, subQuery2) func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { var ( - tableStr string - tableName string - extraArgs []interface{} - tableNames = make([]string, len(tableNameQueryOrStruct)) + tableStr string + tableName string + extraArgs []interface{} ) // Model creation with sub-query. 
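// Since formatWhere now receives the schema and table, struct conditions are mapped to
// the real column names and unknown attributes are filtered out, roughly (on MySQL):
//  type User struct {
//      Id       int
//      Nickname string
//  }
//  one, err := db.Model("user").Where(User{3, "name_3"}).One()
//  // -> WHERE `id`=3 AND `nickname`='name_3'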
if len(tableNameQueryOrStruct) > 1 { conditionStr := gconv.String(tableNameQueryOrStruct[0]) if gstr.Contains(conditionStr, "?") { tableStr, extraArgs = formatWhere( - c.db, conditionStr, tableNameQueryOrStruct[1:], false, + c.db, conditionStr, tableNameQueryOrStruct[1:], false, "", "", ) } } // Normal model creation. if tableStr == "" { + tableNames := make([]string, len(tableNameQueryOrStruct)) for k, v := range tableNameQueryOrStruct { if s, ok := v.(string); ok { tableNames[k] = s @@ -107,7 +113,6 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { tableNames[k] = tableName } } - if len(tableNames) > 1 { tableStr = fmt.Sprintf( `%s AS %s`, c.QuotePrefixTableName(tableNames[0]), c.QuoteWord(tableNames[1]), @@ -129,17 +134,36 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { } } +// Raw creates and returns a model based on a raw sql not a table. +// Example: +// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result) +func (c *Core) Raw(rawSql string, args ...interface{}) *Model { + model := c.Model() + model.rawSql = rawSql + model.extraArgs = args + return model +} + +// Raw creates and returns a model based on a raw sql not a table. +// Example: +// db.Raw("SELECT * FROM `user` WHERE `name` = ?", "john").Scan(&result) +// See Core.Raw. +func (m *Model) Raw(rawSql string, args ...interface{}) *Model { + model := m.db.Raw(rawSql, args...) + model.db = m.db + model.tx = m.tx + return model +} + +func (tx *TX) Raw(rawSql string, args ...interface{}) *Model { + return tx.Model().Raw(rawSql, args...) +} + // With creates and returns an ORM model based on meta data of given object. func (c *Core) With(objects ...interface{}) *Model { return c.db.Model().With(objects...) } -// Table is alias of tx.Model. -// Deprecated, use Model instead. -func (tx *TX) Table(tableNameQueryOrStruct ...interface{}) *Model { - return tx.Model(tableNameQueryOrStruct...) -} - // Model acts like Core.Model except it operates on transaction. // See Core.Model. func (tx *TX) Model(tableNameQueryOrStruct ...interface{}) *Model { diff --git a/database/gdb/gdb_model_condition.go b/database/gdb/gdb_model_condition.go index 7a7d41e5e..cf991c6d3 100644 --- a/database/gdb/gdb_model_condition.go +++ b/database/gdb/gdb_model_condition.go @@ -62,6 +62,8 @@ func (m *Model) WherePri(where interface{}, args ...interface{}) *Model { } // Wheref builds condition string using fmt.Sprintf and arguments. +// Note that if the number of `args` is more than the place holder in `format`, +// the extra `args` will be used as the where condition arguments of the Model. 
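// A usage sketch for the raw-SQL based Model, assuming an existing `user` table:
// chained conditions are appended to the raw statement instead of a table name.
//  var users []User
//  err := db.Raw("SELECT * FROM `user` WHERE `age` > ?", 18).
//      WhereLT("id", 100).
//      OrderDesc("id").
//      Limit(10).
//      Scan(&users)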
func (m *Model) Wheref(format string, args ...interface{}) *Model { var ( placeHolderCount = gstr.Count(format, "?") @@ -275,6 +277,9 @@ func (m *Model) Order(orderBy ...string) *Model { return m } model := m.getModel() + if model.orderBy != "" { + model.orderBy += "," + } model.orderBy = m.db.GetCore().QuoteString(strings.Join(orderBy, " ")) return model } @@ -285,6 +290,9 @@ func (m *Model) OrderAsc(column string) *Model { return m } model := m.getModel() + if model.orderBy != "" { + model.orderBy += "," + } model.orderBy = m.db.GetCore().QuoteWord(column) + " ASC" return model } @@ -295,6 +303,9 @@ func (m *Model) OrderDesc(column string) *Model { return m } model := m.getModel() + if model.orderBy != "" { + model.orderBy += "," + } model.orderBy = m.db.GetCore().QuoteWord(column) + " DESC" return model } @@ -375,7 +386,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh case whereHolderWhere: if conditionWhere == "" { newWhere, newArgs := formatWhere( - m.db, v.where, v.args, m.option&OptionOmitEmpty > 0, + m.db, v.where, v.args, m.option&OptionOmitEmpty > 0, m.schema, m.tables, ) if len(newWhere) > 0 { conditionWhere = newWhere @@ -387,7 +398,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh case whereHolderAnd: newWhere, newArgs := formatWhere( - m.db, v.where, v.args, m.option&OptionOmitEmpty > 0, + m.db, v.where, v.args, m.option&OptionOmitEmpty > 0, m.schema, m.tables, ) if len(newWhere) > 0 { if len(conditionWhere) == 0 { @@ -402,7 +413,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh case whereHolderOr: newWhere, newArgs := formatWhere( - m.db, v.where, v.args, m.option&OptionOmitEmpty > 0, + m.db, v.where, v.args, m.option&OptionOmitEmpty > 0, m.schema, m.tables, ) if len(newWhere) > 0 { if len(conditionWhere) == 0 { @@ -419,7 +430,13 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh } // Soft deletion. softDeletingCondition := m.getConditionForSoftDeleting() - if !m.unscoped && softDeletingCondition != "" { + if m.rawSql != "" && conditionWhere != "" { + if gstr.ContainsI(m.rawSql, " WHERE ") { + conditionWhere = " AND " + conditionWhere + } else { + conditionWhere = " WHERE " + conditionWhere + } + } else if !m.unscoped && softDeletingCondition != "" { if conditionWhere == "" { conditionWhere = fmt.Sprintf(` WHERE %s`, softDeletingCondition) } else { @@ -430,6 +447,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh conditionWhere = " WHERE " + conditionWhere } } + // GROUP BY. if m.groupBy != "" { conditionExtra += " GROUP BY " + m.groupBy @@ -437,7 +455,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh // HAVING. 
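// For a raw-SQL Model the formatted condition is appended to the raw statement: with
// " AND " when the raw SQL already contains a WHERE clause, otherwise with " WHERE ",
// roughly:
//  db.Raw("SELECT * FROM `user` WHERE `status`=1").WhereLT("id", 10).All()
//  // -> SELECT * FROM `user` WHERE `status`=1 AND `id`<10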
if len(m.having) > 0 { havingStr, havingArgs := formatWhere( - m.db, m.having[0], gconv.Interfaces(m.having[1]), m.option&OptionOmitEmpty > 0, + m.db, m.having[0], gconv.Interfaces(m.having[1]), m.option&OptionOmitEmpty > 0, m.schema, m.tables, ) if len(havingStr) > 0 { conditionExtra += " HAVING " + havingStr diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index 09001e7d3..02a60d373 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -8,6 +8,8 @@ package gdb import ( "database/sql" + "fmt" + "github.com/gogf/gf/container/gset" "reflect" "github.com/gogf/gf/errors/gerror" @@ -51,16 +53,20 @@ func (m *Model) Data(data ...interface{}) *Model { switch params := data[0].(type) { case Result: model.data = params.List() + case Record: model.data = params.Map() + case List: list := make(List, len(params)) for k, v := range params { list[k] = gutil.MapCopy(v) } model.data = list + case Map: model.data = gutil.MapCopy(params) + default: var ( rv = reflect.ValueOf(params) @@ -100,6 +106,48 @@ func (m *Model) Data(data ...interface{}) *Model { return model } +// OnDuplicate sets the operations when columns conflicts occurs. +// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice. +// Example: +// OnDuplicate("nickname, age") +// OnDuplicate("nickname", "age") +// OnDuplicate(g.Map{ +// "nickname": gdb.Raw("CONCAT('name_', VALUES(`nickname`))"), +// }) +// OnDuplicate(g.Map{ +// "nickname": "passport", +// }) +func (m *Model) OnDuplicate(onDuplicate ...interface{}) *Model { + model := m.getModel() + if len(onDuplicate) > 1 { + model.onDuplicate = onDuplicate + } else { + model.onDuplicate = onDuplicate[0] + } + return model +} + +// OnDuplicateEx sets the excluding columns for operations when columns conflicts occurs. +// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// The parameter `onDuplicateEx` can be type of string/map/slice. +// Example: +// OnDuplicateEx("passport, password") +// OnDuplicateEx("passport", "password") +// OnDuplicateEx(g.Map{ +// "passport": "", +// "password": "", +// }) +func (m *Model) OnDuplicateEx(onDuplicateEx ...interface{}) *Model { + model := m.getModel() + if len(onDuplicateEx) > 1 { + model.onDuplicateEx = onDuplicateEx + } else { + model.onDuplicateEx = onDuplicateEx[0] + } + return model +} + // Insert does "INSERT INTO ..." statement for the model. // The optional parameter `data` is the same as the parameter of Model.Data function, // see Model.Data. @@ -156,7 +204,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) { } // doInsertWithOption inserts data with option parameter. -func (m *Model) doInsertWithOption(option int) (result sql.Result, err error) { +func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err error) { defer func() { if err == nil { m.checkAndRemoveCache() @@ -166,68 +214,206 @@ func (m *Model) doInsertWithOption(option int) (result sql.Result, err error) { return nil, gerror.New("inserting into table with empty data") } var ( + list List nowString = gtime.Now().String() fieldNameCreate = m.getSoftFieldNameCreated() fieldNameUpdate = m.getSoftFieldNameUpdated() fieldNameDelete = m.getSoftFieldNameDeleted() ) - // Batch operation. 
- if list, ok := m.data.(List); ok { - batch := defaultBatchNumber - if m.batch > 0 { - batch = m.batch - } - newData, err := m.filterDataForInsertOrUpdate(list) - if err != nil { - return nil, err - } - list = newData.(List) - // Automatic handling for creating/updating time. - if !m.unscoped && (fieldNameCreate != "" || fieldNameUpdate != "") { - for k, v := range list { - gutil.MapDelete(v, fieldNameCreate, fieldNameUpdate, fieldNameDelete) - if fieldNameCreate != "" { - v[fieldNameCreate] = nowString - } - if fieldNameUpdate != "" { - v[fieldNameUpdate] = nowString - } - list[k] = v - } - } - return m.db.DoBatchInsert( - m.GetCtx(), - m.getLink(true), - m.tables, - newData, - option, - batch, - ) + newData, err := m.filterDataForInsertOrUpdate(m.data) + if err != nil { + return nil, err } - // Single operation. - if data, ok := m.data.(Map); ok { - newData, err := m.filterDataForInsertOrUpdate(data) - if err != nil { - return nil, err + + // It converts any data to List type for inserting. + switch value := newData.(type) { + case Result: + list = value.List() + + case Record: + list = List{value.Map()} + + case List: + list = value + for i, v := range list { + list[i] = ConvertDataForTableRecord(v) } - data = newData.(Map) - // Automatic handling for creating/updating time. - if !m.unscoped && (fieldNameCreate != "" || fieldNameUpdate != "") { - gutil.MapDelete(data, fieldNameCreate, fieldNameUpdate, fieldNameDelete) + + case Map: + list = List{ConvertDataForTableRecord(value)} + + default: + var ( + rv = reflect.ValueOf(newData) + kind = rv.Kind() + ) + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + // If it's slice type, it then converts it to List type. + case reflect.Slice, reflect.Array: + list = make(List, rv.Len()) + for i := 0; i < rv.Len(); i++ { + list[i] = ConvertDataForTableRecord(rv.Index(i).Interface()) + } + + case reflect.Map: + list = List{ConvertDataForTableRecord(value)} + + case reflect.Struct: + if v, ok := value.(apiInterfaces); ok { + var ( + array = v.Interfaces() + ) + list = make(List, len(array)) + for i := 0; i < len(array); i++ { + list[i] = ConvertDataForTableRecord(array[i]) + } + } else { + list = List{ConvertDataForTableRecord(value)} + } + + default: + return result, gerror.New(fmt.Sprint("unsupported list type:", kind)) + } + } + + if len(list) < 1 { + return result, gerror.New("data list cannot be empty") + } + + // Automatic handling for creating/updating time. + if !m.unscoped && (fieldNameCreate != "" || fieldNameUpdate != "") { + for k, v := range list { + gutil.MapDelete(v, fieldNameCreate, fieldNameUpdate, fieldNameDelete) if fieldNameCreate != "" { - data[fieldNameCreate] = nowString + v[fieldNameCreate] = nowString } if fieldNameUpdate != "" { - data[fieldNameUpdate] = nowString + v[fieldNameUpdate] = nowString + } + list[k] = v + } + } + // Format DoInsertOption, especially for "ON DUPLICATE KEY UPDATE" statement. 
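// On MySQL, the OnDuplicate/OnDuplicateEx settings collected here end up as an
// "ON DUPLICATE KEY UPDATE" clause of the generated INSERT. A sketch, assuming a
// unique key on `id`:
//  db.Model("user").
//      Data(g.Map{"id": 1, "passport": "p1", "nickname": "n1"}).
//      OnDuplicate("passport").
//      Save()
//  // roughly: INSERT INTO `user`(...) VALUES(...)
//  //          ON DUPLICATE KEY UPDATE `passport`=VALUES(`passport`)
// OnDuplicateEx works the other way round: on key conflict it updates every written
// column except the excluded ones.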
+ columnNames := make([]string, 0, len(list[0])) + for k, _ := range list[0] { + columnNames = append(columnNames, k) + } + doInsertOption, err := m.formatDoInsertOption(insertOption, columnNames) + if err != nil { + return result, err + } + + return m.db.DoInsert(m.GetCtx(), m.getLink(true), m.tables, list, doInsertOption) +} + +func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (option DoInsertOption, err error) { + option = DoInsertOption{ + InsertOption: insertOption, + BatchCount: m.getBatch(), + } + if insertOption == insertOptionSave { + onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx) + if err != nil { + return option, err + } + var ( + onDuplicateExKeySet = gset.NewStrSetFrom(onDuplicateExKeys) + ) + if m.onDuplicate != nil { + switch m.onDuplicate.(type) { + case Raw, *Raw: + option.OnDuplicateStr = gconv.String(m.onDuplicate) + + default: + var ( + reflectValue = reflect.ValueOf(m.onDuplicate) + reflectKind = reflectValue.Kind() + ) + for reflectKind == reflect.Ptr { + reflectValue = reflectValue.Elem() + reflectKind = reflectValue.Kind() + } + switch reflectKind { + case reflect.String: + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range gstr.SplitAndTrim(reflectValue.String(), ",") { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v + } + + case reflect.Map: + option.OnDuplicateMap = make(map[string]interface{}) + for k, v := range gconv.Map(m.onDuplicate) { + if onDuplicateExKeySet.Contains(k) { + continue + } + option.OnDuplicateMap[k] = v + } + + case reflect.Slice, reflect.Array: + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range gconv.Strings(m.onDuplicate) { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v + } + + default: + return option, gerror.Newf(`unsupported OnDuplicate parameter type "%s"`, reflect.TypeOf(m.onDuplicate)) + } + } + } else if onDuplicateExKeySet.Size() > 0 { + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range columnNames { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v } } - return m.db.DoInsert( - m.GetCtx(), - m.getLink(true), - m.tables, - newData, - option, - ) } - return nil, gerror.New("inserting into table with invalid data type") + return +} + +func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, error) { + if onDuplicateEx == nil { + return nil, nil + } + + var ( + reflectValue = reflect.ValueOf(onDuplicateEx) + reflectKind = reflectValue.Kind() + ) + for reflectKind == reflect.Ptr { + reflectValue = reflectValue.Elem() + reflectKind = reflectValue.Kind() + } + switch reflectKind { + case reflect.String: + return gstr.SplitAndTrim(reflectValue.String(), ","), nil + + case reflect.Map: + return gutil.Keys(onDuplicateEx), nil + + case reflect.Slice, reflect.Array: + return gconv.Strings(onDuplicateEx), nil + + default: + return nil, gerror.Newf(`unsupported OnDuplicateEx parameter type "%s"`, reflect.TypeOf(onDuplicateEx)) + } +} + +func (m *Model) getBatch() int { + batch := defaultBatchNumber + if m.batch > 0 { + batch = m.batch + } + return batch } diff --git a/database/gdb/gdb_model_join.go b/database/gdb/gdb_model_join.go index 10e4c9255..faf993d0e 100644 --- a/database/gdb/gdb_model_join.go +++ b/database/gdb/gdb_model_join.go @@ -56,9 +56,9 @@ func (m *Model) InnerJoin(table ...string) *Model { // doJoin does "LEFT/RIGHT/INNER JOIN ... ON ..." statement on the model. 
// The parameter `table` can be joined table and its joined condition, // and also with its alias name, like: -// Table("user").InnerJoin("user_detail", "user_detail.uid=user.uid") -// Table("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") -// Table("user", "u").InnerJoin("SELECT xxx FROM xxx AS a", "a.uid=u.uid") +// Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid") +// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") +// Model("user", "u").InnerJoin("SELECT xxx FROM xxx AS a", "a.uid=u.uid") // Related issues: // https://github.com/gogf/gf/issues/1024 func (m *Model) doJoin(operator string, table ...string) *Model { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index ff366f218..cd39c4332 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -8,6 +8,7 @@ package gdb import ( "fmt" + "github.com/gogf/gf/errors/gerror" "reflect" "github.com/gogf/gf/container/gset" @@ -18,11 +19,6 @@ import ( "github.com/gogf/gf/util/gconv" ) -const ( - queryTypeNormal = "NormalQuery" - queryTypeCount = "CountQuery" -) - // Select is alias of Model.All. // See Model.All. // Deprecated, use All instead. @@ -200,6 +196,15 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) { return all.Array(), nil } +// Struct retrieves one record from table and converts it into given struct. +// The parameter `pointer` should be type of *struct/**struct. If type **struct is given, +// it can create the struct internally during converting. +// +// Deprecated, use Scan instead. +func (m *Model) Struct(pointer interface{}, where ...interface{}) error { + return m.doStruct(pointer, where...) +} + // Struct retrieves one record from table and converts it into given struct. // The parameter `pointer` should be type of *struct/**struct. If type **struct is given, // it can create the struct internally during converting. @@ -207,24 +212,38 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) { // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. // -// Note that it returns sql.ErrNoRows if there's no record retrieved with the given conditions -// from table and `pointer` is not nil. +// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has +// default value and there's no record retrieved with the given conditions from table. // -// Eg: +// Example: // user := new(User) -// err := db.Model("user").Where("id", 1).Struct(user) +// err := db.Model("user").Where("id", 1).Scan(user) // // user := (*User)(nil) -// err := db.Model("user").Where("id", 1).Struct(&user) -func (m *Model) Struct(pointer interface{}, where ...interface{}) error { - one, err := m.One(where...) +// err := db.Model("user").Where("id", 1).Scan(&user) +func (m *Model) doStruct(pointer interface{}, where ...interface{}) error { + model := m + // Auto selecting fields by struct attributes. + if model.fieldsEx == "" && (model.fields == "" || model.fields == "*") { + model = m.Fields(pointer) + } + one, err := model.One(where...) if err != nil { return err } if err = one.Struct(pointer); err != nil { return err } - return m.doWithScanStruct(pointer) + return model.doWithScanStruct(pointer) +} + +// Structs retrieves records from table and converts them into given struct slice. +// The parameter `pointer` should be type of *[]struct/*[]*struct. It can create and fill the struct +// slice internally during converting. 
+// +// Deprecated, use Scan instead. +func (m *Model) Structs(pointer interface{}, where ...interface{}) error { + return m.doStructs(pointer, where...) } // Structs retrieves records from table and converts them into given struct slice. @@ -234,37 +253,45 @@ func (m *Model) Struct(pointer interface{}, where ...interface{}) error { // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. // -// Note that it returns sql.ErrNoRows if there's no record retrieved with the given conditions -// from table and `pointer` is not empty. +// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has +// default value and there's no record retrieved with the given conditions from table. // -// Eg: +// Example: // users := ([]User)(nil) -// err := db.Model("user").Structs(&users) +// err := db.Model("user").Scan(&users) // // users := ([]*User)(nil) -// err := db.Model("user").Structs(&users) -func (m *Model) Structs(pointer interface{}, where ...interface{}) error { - all, err := m.All(where...) +// err := db.Model("user").Scan(&users) +func (m *Model) doStructs(pointer interface{}, where ...interface{}) error { + model := m + // Auto selecting fields by struct attributes. + if model.fieldsEx == "" && (model.fields == "" || model.fields == "*") { + model = m.Fields( + reflect.New( + reflect.ValueOf(pointer).Elem().Type().Elem(), + ).Interface(), + ) + } + all, err := model.All(where...) if err != nil { return err } if err = all.Structs(pointer); err != nil { return err } - return m.doWithScanStructs(pointer) + return model.doWithScanStructs(pointer) } // Scan automatically calls Struct or Structs function according to the type of parameter `pointer`. -// It calls function Struct if `pointer` is type of *struct/**struct. -// It calls function Structs if `pointer` is type of *[]struct/*[]*struct. +// It calls function doStruct if `pointer` is type of *struct/**struct. +// It calls function doStructs if `pointer` is type of *[]struct/*[]*struct. // -// The optional parameter `where` is the same as the parameter of Model.Where function, -// see Model.Where. +// The optional parameter `where` is the same as the parameter of Model.Where function, see Model.Where. // -// Note that it returns sql.ErrNoRows if there's no record retrieved with the given conditions -// from table. +// Note that it returns sql.ErrNoRows if the given parameter `pointer` pointed to a variable that has +// default value and there's no record retrieved with the given conditions from table. // -// Eg: +// Example: // user := new(User) // err := db.Model("user").Where("id", 1).Scan(user) // @@ -277,16 +304,35 @@ func (m *Model) Structs(pointer interface{}, where ...interface{}) error { // users := ([]*User)(nil) // err := db.Model("user").Scan(&users) func (m *Model) Scan(pointer interface{}, where ...interface{}) error { - var reflectType reflect.Type + var ( + reflectValue reflect.Value + reflectKind reflect.Kind + ) if v, ok := pointer.(reflect.Value); ok { - reflectType = v.Type() + reflectValue = v } else { - reflectType = reflect.TypeOf(pointer) + reflectValue = reflect.ValueOf(pointer) } - if gstr.Contains(reflectType.String(), "[]") { - return m.Structs(pointer, where...) 
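// Scan now inspects the pointer kind via reflection: slice/array pointers go to
// doStructs, struct (or nil **struct) pointers go to doStruct, and a non-pointer
// argument is rejected with an error. Both paths select only the struct attributes as
// columns when no Fields() was set, e.g. with an illustrative UserNick type:
//  type UserNick struct {
//      Id       int
//      Nickname string
//  }
//  var u *UserNick
//  err := db.Model("user").Where("id", 1).Scan(&u)
//  // roughly: SELECT `id`,`nickname` FROM `user` WHERE `id`=1 LIMIT 1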
+ + reflectKind = reflectValue.Kind() + if reflectKind != reflect.Ptr { + return gerror.New(`the parameter "pointer" for function Scan should type of pointer`) + } + for reflectKind == reflect.Ptr { + reflectValue = reflectValue.Elem() + reflectKind = reflectValue.Kind() + } + + switch reflectKind { + case reflect.Slice, reflect.Array: + return m.doStructs(pointer, where...) + + case reflect.Struct, reflect.Invalid: + return m.doStruct(pointer, where...) + + default: + return gerror.New(`element of parameter "pointer" for function Scan should type of struct/*struct/[]struct/[]*struct`) } - return m.Struct(pointer, where...) } // ScanList converts `r` to struct slice which contains other complex struct attributes. @@ -458,6 +504,16 @@ func (m *Model) FindScan(pointer interface{}, where ...interface{}) error { return m.Scan(pointer) } +// Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement for the model. +func (m *Model) Union(unions ...*Model) *Model { + return m.db.Union(unions...) +} + +// UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement for the model. +func (m *Model) UnionAll(unions ...*Model) *Model { + return m.db.UnionAll(unions...) +} + // doGetAllBySql does the select statement on the database. func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, err error) { cacheKey := "" @@ -490,18 +546,18 @@ func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, e if cacheKey != "" && err == nil { if m.cacheDuration < 0 { if _, err := cacheObj.Remove(cacheKey); err != nil { - intlog.Error(err) + intlog.Error(m.GetCtx(), err) } } else { if err := cacheObj.Set(cacheKey, result, m.cacheDuration); err != nil { - intlog.Error(err) + intlog.Error(m.GetCtx(), err) } } } return result, err } -func (m *Model) getFormattedSqlAndArgs(queryType string, limit1 bool) (sqlWithHolder string, holderArgs []interface{}) { +func (m *Model) getFormattedSqlAndArgs(queryType int, limit1 bool) (sqlWithHolder string, holderArgs []interface{}) { switch queryType { case queryTypeCount: countFields := "COUNT(1)" @@ -510,6 +566,11 @@ func (m *Model) getFormattedSqlAndArgs(queryType string, limit1 bool) (sqlWithHo // DISTINCT t.user_id uid countFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields) } + // Raw SQL Model. + if m.rawSql != "" { + sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", countFields, m.rawSql) + return sqlWithHolder, nil + } conditionWhere, conditionExtra, conditionArgs := m.formatCondition(false, true) sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", countFields, m.tables, conditionWhere+conditionExtra) if len(m.groupBy) > 0 { @@ -519,6 +580,15 @@ func (m *Model) getFormattedSqlAndArgs(queryType string, limit1 bool) (sqlWithHo default: conditionWhere, conditionExtra, conditionArgs := m.formatCondition(limit1, false) + // Raw SQL Model, especially for UNION/UNION ALL featured SQL. 
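// Union and UnionAll compose several sub-selects into a raw-SQL Model, which is why a
// COUNT query on a raw Model is wrapped as "SELECT COUNT(1) FROM (<raw sql>) AS T"
// above. A usage sketch, assuming Union is also exposed on the DB/Core level:
//  m := db.Union(
//      db.Model("user").Where("id", 1),
//      db.Model("user").Where("id", 2),
//  )
//  n, err := m.Count()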
+ if m.rawSql != "" { + sqlWithHolder = fmt.Sprintf( + "%s%s", + m.rawSql, + conditionWhere+conditionExtra, + ) + return sqlWithHolder, conditionArgs + } // DO NOT quote the m.fields where, in case of fields like: // DISTINCT t.user_id uid sqlWithHolder = fmt.Sprintf( diff --git a/database/gdb/gdb_model_time.go b/database/gdb/gdb_model_time.go index e4047e4a1..76c1f0667 100644 --- a/database/gdb/gdb_model_time.go +++ b/database/gdb/gdb_model_time.go @@ -173,6 +173,9 @@ func (m *Model) getConditionOfTableStringForSoftDeleting(s string) string { // getPrimaryTableName parses and returns the primary table name. func (m *Model) getPrimaryTableName() string { + if m.tables == "" { + return "" + } array1 := gstr.SplitAndTrim(m.tables, ",") array2 := gstr.SplitAndTrim(array1[0], " ") array3 := gstr.SplitAndTrim(array2[0], ".") diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index b32d1f5f8..f72e622ff 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -93,17 +93,19 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro } // Increment increments a column's value by a given amount. -func (m *Model) Increment(column string, amount float64) (sql.Result, error) { +// The parameter `amount` can be type of float or integer. +func (m *Model) Increment(column string, amount interface{}) (sql.Result, error) { return m.getModel().Data(column, &Counter{ Field: column, - Value: amount, + Value: gconv.Float64(amount), }).Update() } // Decrement decrements a column's value by a given amount. -func (m *Model) Decrement(column string, amount float64) (sql.Result, error) { +// The parameter `amount` can be type of float or integer. +func (m *Model) Decrement(column string, amount interface{}) (sql.Result, error) { return m.getModel().Data(column, &Counter{ Field: column, - Value: -amount, + Value: -gconv.Float64(amount), }).Update() } diff --git a/database/gdb/gdb_statement.go b/database/gdb/gdb_statement.go index 81f14affe..51d8a4b2f 100644 --- a/database/gdb/gdb_statement.go +++ b/database/gdb/gdb_statement.go @@ -37,7 +37,7 @@ const ( ) // doStmtCommit commits statement according to given `stmtType`. -func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interface{}) (result interface{}, err error) { +func (s *Stmt) doStmtCommit(ctx context.Context, stmtType string, args ...interface{}) (result interface{}, err error) { var ( cancelFuncForTimeout context.CancelFunc timestampMilli1 = gtime.TimestampMilli() @@ -86,7 +86,7 @@ func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interf // ExecContext executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - result, err := s.doStmtCommit(stmtTypeExecContext, ctx, args...) + result, err := s.doStmtCommit(ctx, stmtTypeExecContext, args...) if result != nil { return result.(sql.Result), err } @@ -96,7 +96,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result // QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { - result, err := s.doStmtCommit(stmtTypeQueryContext, ctx, args...) + result, err := s.doStmtCommit(ctx, stmtTypeQueryContext, args...) 
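// Increment and Decrement now accept the amount as interface{}, so integer and float
// amounts both work; the value is converted with gconv.Float64 and applied through a
// Counter, e.g. with illustrative columns:
//  _, err := db.Model("user").Where("id", 1).Increment("login_count", 1)
//  _, err = db.Model("user").Where("id", 1).Decrement("balance", 0.5)
//  // roughly: UPDATE `user` SET `login_count`=`login_count`+1 WHERE `id`=1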
if result != nil { return result.(*sql.Rows), err } @@ -110,7 +110,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows // Otherwise, the *Row's Scan scans the first selected row and discards // the rest. func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { - result, _ := s.doStmtCommit(stmtTypeQueryRowContext, ctx, args...) + result, _ := s.doStmtCommit(ctx, stmtTypeQueryRowContext, args...) if result != nil { return result.(*sql.Row) } diff --git a/database/gdb/gdb_type_record.go b/database/gdb/gdb_type_record.go index 3685dd296..a660ec4ef 100644 --- a/database/gdb/gdb_type_record.go +++ b/database/gdb/gdb_type_record.go @@ -52,7 +52,7 @@ func (r Record) Struct(pointer interface{}) error { } return nil } - return gconv.StructTag(r.Map(), pointer, OrmTagForStruct) + return gconv.StructTag(r, pointer, OrmTagForStruct) } // IsEmpty checks and returns whether `r` is empty. diff --git a/database/gdb/gdb_type_record_deprecated.go b/database/gdb/gdb_type_record_deprecated.go deleted file mode 100644 index d07a79807..000000000 --- a/database/gdb/gdb_type_record_deprecated.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. -// -// This Source Code Form is subject to the terms of the MIT License. -// If a copy of the MIT was not distributed with this file, -// You can obtain one at https://github.com/gogf/gf. - -package gdb - -// Deprecated, use Json instead. -func (r Record) ToJson() string { - return r.Json() -} - -// Deprecated, use Xml instead. -func (r Record) ToXml(rootTag ...string) string { - return r.Xml(rootTag...) -} - -// Deprecated, use Map instead. -func (r Record) ToMap() Map { - return r.Map() -} - -// Deprecated, use Struct instead. -func (r Record) ToStruct(pointer interface{}) error { - return r.Struct(pointer) -} diff --git a/database/gdb/gdb_type_result.go b/database/gdb/gdb_type_result.go index a5800ece9..489e3bc47 100644 --- a/database/gdb/gdb_type_result.go +++ b/database/gdb/gdb_type_result.go @@ -153,7 +153,7 @@ func (r Result) MapKeyUint(key string) map[uint]Map { return m } -// RecordKeyInt converts `r` to a map[int]Record of which key is specified by `key`. +// RecordKeyStr converts `r` to a map[string]Record of which key is specified by `key`. func (r Result) RecordKeyStr(key string) map[string]Record { m := make(map[string]Record) for _, item := range r { @@ -189,5 +189,5 @@ func (r Result) RecordKeyUint(key string) map[uint]Record { // Structs converts `r` to struct slice. // Note that the parameter `pointer` should be type of *[]struct/*[]*struct. func (r Result) Structs(pointer interface{}) (err error) { - return gconv.StructsTag(r.List(), pointer, OrmTagForStruct) + return gconv.StructsTag(r, pointer, OrmTagForStruct) } diff --git a/database/gdb/gdb_type_result_deprecated.go b/database/gdb/gdb_type_result_deprecated.go deleted file mode 100644 index 8f6faf3da..000000000 --- a/database/gdb/gdb_type_result_deprecated.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. -// -// This Source Code Form is subject to the terms of the MIT License. -// If a copy of the MIT was not distributed with this file, -// You can obtain one at https://github.com/gogf/gf. - -package gdb - -// Deprecated, use Json instead. -func (r Result) ToJson() string { - return r.Json() -} - -// Deprecated, use Xml instead. -func (r Result) ToXml(rootTag ...string) string { - return r.Xml(rootTag...) 
-} - -// Deprecated, use List instead. -func (r Result) ToList() List { - return r.List() -} - -// Deprecated, use MapKeyStr instead. -func (r Result) ToStringMap(key string) map[string]Map { - return r.MapKeyStr(key) -} - -// Deprecated, use MapKetInt instead. -func (r Result) ToIntMap(key string) map[int]Map { - return r.MapKeyInt(key) -} - -// Deprecated, use MapKeyUint instead. -func (r Result) ToUintMap(key string) map[uint]Map { - return r.MapKeyUint(key) -} - -// Deprecated, use RecordKeyStr instead. -func (r Result) ToStringRecord(key string) map[string]Record { - return r.RecordKeyStr(key) -} - -// Deprecated, use RecordKetInt instead. -func (r Result) ToIntRecord(key string) map[int]Record { - return r.RecordKeyInt(key) -} - -// Deprecated, use RecordKetUint instead. -func (r Result) ToUintRecord(key string) map[uint]Record { - return r.RecordKeyUint(key) -} - -// Deprecated, use Structs instead. -func (r Result) ToStructs(pointer interface{}) (err error) { - return r.Structs(pointer) -} diff --git a/database/gdb/gdb_z_driver_test.go b/database/gdb/gdb_z_driver_test.go index e166856d9..848a96a04 100644 --- a/database/gdb/gdb_z_driver_test.go +++ b/database/gdb/gdb_z_driver_test.go @@ -18,9 +18,9 @@ import ( // MyDriver is a custom database driver, which is used for testing only. // For simplifying the unit testing case purpose, MyDriver struct inherits the mysql driver -// gdb.DriverMysql and overwrites its function HandleSqlBeforeCommit. -// So if there's any sql execution, it goes through MyDriver.HandleSqlBeforeCommit firstly and -// then gdb.DriverMysql.HandleSqlBeforeCommit. +// gdb.DriverMysql and overwrites its function DoCommit. +// So if there's any sql execution, it goes through MyDriver.DoCommit firstly and +// then gdb.DriverMysql.DoCommit. // You can call it sql "HOOK" or "HiJack" as your will. type MyDriver struct { *gdb.DriverMysql @@ -41,11 +41,11 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) { }, nil } -// HandleSqlBeforeCommit handles the sql before posts it to database. +// DoCommit handles the sql before posts it to database. // It here overwrites the same method of gdb.DriverMysql and makes some custom changes. -func (d *MyDriver) HandleSqlBeforeCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (string, []interface{}) { +func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { latestSqlString.Set(sql) - return d.DriverMysql.HandleSqlBeforeCommit(ctx, link, sql, args) + return d.DriverMysql.DoCommit(ctx, link, sql, args) } func init() { diff --git a/database/gdb/gdb_z_init_test.go b/database/gdb/gdb_z_init_test.go index abe205257..9e73bf49c 100644 --- a/database/gdb/gdb_z_init_test.go +++ b/database/gdb/gdb_z_init_test.go @@ -7,6 +7,7 @@ package gdb_test import ( + "context" "fmt" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/frame/g" @@ -29,9 +30,10 @@ const ( ) var ( - db gdb.DB - dbPrefix gdb.DB - configNode gdb.ConfigNode + db gdb.DB + dbPrefix gdb.DB + dbCtxStrict gdb.DB + configNode gdb.ConfigNode ) func init() { @@ -56,9 +58,15 @@ func init() { } nodePrefix := configNode nodePrefix.Prefix = TableNamePrefix1 + + nodeCtxStrict := configNode + nodeCtxStrict.CtxStrict = true + gdb.AddConfigNode("test", configNode) gdb.AddConfigNode("prefix", nodePrefix) + gdb.AddConfigNode("ctxstrict", nodeCtxStrict) gdb.AddConfigNode(gdb.DefaultGroupName, configNode) + // Default db. 
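// A sketch of wiring such a DoCommit hook into a configuration group, assuming the
// gdb.Register driver registration and a matching Type value ("MyDriver" being the
// test driver above):
//  gdb.Register("MyDriver", &MyDriver{})
//  gdb.AddConfigNode("driver-test", gdb.ConfigNode{
//      Type: "MyDriver",
//      Link: "root:12345678@tcp(127.0.0.1:3306)/test",
//  })
//  hookedDB, _ := gdb.New("driver-test")
//  _, _ = hookedDB.Ctx(context.TODO()).Model("user").All() // goes through MyDriver.DoCommit first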
if r, err := gdb.New(); err != nil { gtest.Error(err) @@ -87,6 +95,20 @@ func init() { gtest.Error(err) } dbPrefix.SetSchema(TestSchema1) + + // CtxStrict db. + if r, err := gdb.New("ctxstrict"); err != nil { + gtest.Error(err) + } else { + dbCtxStrict = r + } + if _, err := dbCtxStrict.Ctx(context.TODO()).Exec(fmt.Sprintf(schemaTemplate, TestSchema1)); err != nil { + gtest.Error(err) + } + if _, err := dbCtxStrict.Ctx(context.TODO()).Exec(fmt.Sprintf(schemaTemplate, TestSchema2)); err != nil { + gtest.Error(err) + } + dbCtxStrict.SetSchema(TestSchema1) } func createTable(table ...string) string { @@ -111,7 +133,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { switch configNode.Type { case "sqlite": - if _, err := db.Exec(fmt.Sprintf(` + if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(` CREATE TABLE %s ( id bigint unsigned NOT NULL AUTO_INCREMENT, passport varchar(45), @@ -124,7 +146,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { gtest.Fatal(err) } case "pgsql": - if _, err := db.Exec(fmt.Sprintf(` + if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(` CREATE TABLE %s ( id bigint NOT NULL, passport varchar(45), @@ -137,7 +159,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { gtest.Fatal(err) } case "mssql": - if _, err := db.Exec(fmt.Sprintf(` + if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(` IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='%s' and xtype='U') CREATE TABLE %s ( ID numeric(10,0) NOT NULL, @@ -151,7 +173,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { gtest.Fatal(err) } case "oracle": - if _, err := db.Exec(fmt.Sprintf(` + if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(` CREATE TABLE %s ( ID NUMBER(10) NOT NULL, PASSPORT VARCHAR(45) NOT NULL, @@ -164,7 +186,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { gtest.Fatal(err) } case "mysql": - if _, err := db.Exec(fmt.Sprintf(` + if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(` CREATE TABLE %s ( id int(10) unsigned NOT NULL AUTO_INCREMENT, passport varchar(45) NULL, @@ -195,7 +217,7 @@ func createInitTableWithDb(db gdb.DB, table ...string) (name string) { }) } - result, err := db.BatchInsert(name, array.Slice()) + result, err := db.Ctx(context.TODO()).Insert(name, array.Slice()) gtest.AssertNil(err) n, e := result.RowsAffected() @@ -205,7 +227,7 @@ func createInitTableWithDb(db gdb.DB, table ...string) (name string) { } func dropTableWithDb(db gdb.DB, table string) { - if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil { + if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil { gtest.Error(err) } } diff --git a/database/gdb/gdb_z_mysql_ctx_test.go b/database/gdb/gdb_z_mysql_ctx_test.go index 33eb23125..0d3cf753a 100644 --- a/database/gdb/gdb_z_mysql_ctx_test.go +++ b/database/gdb/gdb_z_mysql_ctx_test.go @@ -62,3 +62,23 @@ func Test_Ctx_Model(t *testing.T) { db.Model(table).All() }) } + +func Test_Ctx_Strict(t *testing.T) { + table := createInitTableWithDb(dbCtxStrict) + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + _, err := dbCtxStrict.Query("select 1") + t.AssertNE(err, nil) + }) + gtest.C(t, func(t *gtest.T) { + r, err := dbCtxStrict.Model(table).All() + t.AssertNE(err, nil) + t.Assert(len(r), 0) + }) + gtest.C(t, func(t *gtest.T) { + r, err := dbCtxStrict.Model(table).Ctx(context.TODO()).All() + t.AssertNil(err) + t.Assert(len(r), TableSize) + }) +} diff --git 
a/database/gdb/gdb_z_mysql_method_test.go b/database/gdb/gdb_z_mysql_method_test.go index d83f2f364..5ac4e2730 100644 --- a/database/gdb/gdb_z_mysql_method_test.go +++ b/database/gdb/gdb_z_mysql_method_test.go @@ -329,7 +329,7 @@ func Test_DB_BatchInsert(t *testing.T) { gtest.C(t, func(t *gtest.T) { table := createTable() defer dropTable(table) - r, err := db.BatchInsert(table, g.List{ + r, err := db.Insert(table, g.List{ { "id": 2, "passport": "t2", @@ -357,7 +357,7 @@ func Test_DB_BatchInsert(t *testing.T) { table := createTable() defer dropTable(table) // []interface{} - r, err := db.BatchInsert(table, g.Slice{ + r, err := db.Insert(table, g.Slice{ g.Map{ "id": 2, "passport": "t2", @@ -382,7 +382,7 @@ func Test_DB_BatchInsert(t *testing.T) { gtest.C(t, func(t *gtest.T) { table := createTable() defer dropTable(table) - result, err := db.BatchInsert(table, g.Map{ + result, err := db.Insert(table, g.Map{ "id": 1, "passport": "t1", "password": "p1", @@ -416,7 +416,7 @@ func Test_DB_BatchInsert_Struct(t *testing.T) { NickName: "T1", CreateTime: gtime.Now(), } - result, err := db.BatchInsert(table, user) + result, err := db.Insert(table, user) t.AssertNil(err) n, _ := result.RowsAffected() t.Assert(n, 1) @@ -584,7 +584,7 @@ func Test_DB_GetStruct(t *testing.T) { CreateTime gtime.Time } user := new(User) - err := db.GetStruct(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3) + err := db.GetScan(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3) t.AssertNil(err) t.Assert(user.NickName, "name_3") }) @@ -597,7 +597,7 @@ func Test_DB_GetStruct(t *testing.T) { CreateTime *gtime.Time } user := new(User) - err := db.GetStruct(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3) + err := db.GetScan(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3) t.AssertNil(err) t.Assert(user.NickName, "name_3") }) @@ -615,7 +615,7 @@ func Test_DB_GetStructs(t *testing.T) { CreateTime gtime.Time } var users []User - err := db.GetStructs(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1) + err := db.GetScan(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1) t.AssertNil(err) t.Assert(len(users), TableSize-1) t.Assert(users[0].Id, 2) @@ -635,7 +635,7 @@ func Test_DB_GetStructs(t *testing.T) { CreateTime *gtime.Time } var users []User - err := db.GetStructs(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1) + err := db.GetScan(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1) t.AssertNil(err) t.Assert(len(users), TableSize-1) t.Assert(users[0].Id, 2) @@ -1283,7 +1283,7 @@ func Test_DB_Prefix(t *testing.T) { }) } - result, err := db.BatchInsert(name, array.Slice()) + result, err := db.Insert(name, array.Slice()) t.AssertNil(err) n, e := result.RowsAffected() diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 7370e5bfc..923ccb68a 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -899,7 +899,7 @@ func Test_Model_Struct(t *testing.T) { CreateTime gtime.Time } user := new(User) - err := db.Model(table).Where("id=1").Struct(user) + err := db.Model(table).Where("id=1").Scan(user) t.AssertNil(err) t.Assert(user.NickName, "name_1") t.Assert(user.CreateTime.String(), "2018-10-24 10:00:00") @@ -913,7 +913,7 @@ func Test_Model_Struct(t *testing.T) { CreateTime *gtime.Time } user := new(User) - err := db.Model(table).Where("id=1").Struct(user) + err := db.Model(table).Where("id=1").Scan(user) t.AssertNil(err) t.Assert(user.NickName, "name_1") 
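// GetScan covers both of the old GetStruct and GetStructs calls on the DB level and
// dispatches on the pointer type, e.g.:
//  user := new(User)
//  err := db.GetScan(user, "SELECT * FROM `user` WHERE id=?", 3)
//  var users []User
//  err = db.GetScan(&users, "SELECT * FROM `user` WHERE id>?", 1)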
t.Assert(user.CreateTime.String(), "2018-10-24 10:00:00") @@ -928,7 +928,7 @@ func Test_Model_Struct(t *testing.T) { CreateTime *gtime.Time } user := (*User)(nil) - err := db.Model(table).Where("id=1").Struct(&user) + err := db.Model(table).Where("id=1").Scan(&user) t.AssertNil(err) t.Assert(user.NickName, "name_1") t.Assert(user.CreateTime.String(), "2018-10-24 10:00:00") @@ -960,7 +960,7 @@ func Test_Model_Struct(t *testing.T) { CreateTime *gtime.Time } user := new(User) - err := db.Model(table).Where("id=-1").Struct(user) + err := db.Model(table).Where("id=-1").Scan(user) t.Assert(err, sql.ErrNoRows) }) gtest.C(t, func(t *gtest.T) { @@ -972,7 +972,7 @@ func Test_Model_Struct(t *testing.T) { CreateTime *gtime.Time } var user *User - err := db.Model(table).Where("id=-1").Struct(&user) + err := db.Model(table).Where("id=-1").Scan(&user) t.AssertNil(err) }) } @@ -992,7 +992,7 @@ func Test_Model_Struct_CustomType(t *testing.T) { CreateTime gtime.Time } user := new(User) - err := db.Model(table).Where("id=1").Struct(user) + err := db.Model(table).Where("id=1").Scan(user) t.AssertNil(err) t.Assert(user.NickName, "name_1") t.Assert(user.CreateTime.String(), "2018-10-24 10:00:00") @@ -1012,7 +1012,7 @@ func Test_Model_Structs(t *testing.T) { CreateTime gtime.Time } var users []User - err := db.Model(table).Order("id asc").Structs(&users) + err := db.Model(table).Order("id asc").Scan(&users) if err != nil { gtest.Error(err) } @@ -1035,7 +1035,7 @@ func Test_Model_Structs(t *testing.T) { CreateTime *gtime.Time } var users []*User - err := db.Model(table).Order("id asc").Structs(&users) + err := db.Model(table).Order("id asc").Scan(&users) if err != nil { gtest.Error(err) } @@ -1081,38 +1081,40 @@ func Test_Model_Structs(t *testing.T) { CreateTime *gtime.Time } var users []*User - err := db.Model(table).Where("id<0").Structs(&users) + err := db.Model(table).Where("id<0").Scan(&users) t.AssertNil(err) }) } -func Test_Model_StructsWithJsonTag(t *testing.T) { - table := createInitTable() - defer dropTable(table) - - gtest.C(t, func(t *gtest.T) { - type User struct { - Uid int `json:"id"` - Passport string - Password string - Name string `json:"nick_name"` - Time gtime.Time `json:"create_time"` - } - var users []User - err := db.Model(table).Order("id asc").Structs(&users) - if err != nil { - gtest.Error(err) - } - t.Assert(len(users), TableSize) - t.Assert(users[0].Uid, 1) - t.Assert(users[1].Uid, 2) - t.Assert(users[2].Uid, 3) - t.Assert(users[0].Name, "name_1") - t.Assert(users[1].Name, "name_2") - t.Assert(users[2].Name, "name_3") - t.Assert(users[0].Time.String(), "2018-10-24 10:00:00") - }) -} +// JSON tag is only used for JSON Marshal/Unmarshal, DO NOT use it in multiple purposes! 
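// For column mapping use the `orm` tag (see OrmTagForStruct) and keep the `json` tag
// for JSON marshalling only, e.g.:
//  type User struct {
//      Uid  int         `orm:"id"          json:"id"`
//      Name string      `orm:"nickname"    json:"nick_name"`
//      Time *gtime.Time `orm:"create_time" json:"create_time"`
//  }
//  var users []User
//  err := db.Model("user").Order("id asc").Scan(&users)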
+//func Test_Model_StructsWithJsonTag(t *testing.T) { +// table := createInitTable() +// defer dropTable(table) +// +// db.SetDebug(true) +// gtest.C(t, func(t *gtest.T) { +// type User struct { +// Uid int `json:"id"` +// Passport string +// Password string +// Name string `json:"nick_name"` +// Time gtime.Time `json:"create_time"` +// } +// var users []User +// err := db.Model(table).Order("id asc").Scan(&users) +// if err != nil { +// gtest.Error(err) +// } +// t.Assert(len(users), TableSize) +// t.Assert(users[0].Uid, 1) +// t.Assert(users[1].Uid, 2) +// t.Assert(users[2].Uid, 3) +// t.Assert(users[0].Name, "name_1") +// t.Assert(users[1].Name, "name_2") +// t.Assert(users[2].Name, "name_3") +// t.Assert(users[0].Time.String(), "2018-10-24 10:00:00") +// }) +//} func Test_Model_Scan(t *testing.T) { table := createInitTable() @@ -1469,11 +1471,11 @@ func Test_Model_Where(t *testing.T) { t.Assert(len(result), 3) t.Assert(result[0]["id"].Int(), 1) }) - // struct + // struct, automatic mapping and filtering. gtest.C(t, func(t *gtest.T) { type User struct { - Id int `json:"id"` - Nickname string `gconv:"nickname"` + Id int + Nickname string } result, err := db.Model(table).Where(User{3, "name_3"}).One() t.AssertNil(err) @@ -3098,7 +3100,7 @@ func Test_TimeZoneInsert(t *testing.T) { gtest.C(t, func(t *gtest.T) { _, _ = db.Model(tableName).Unscoped().Insert(u) userEntity := &User{} - err := db.Model(tableName).Where("id", 1).Unscoped().Struct(&userEntity) + err := db.Model(tableName).Where("id", 1).Unscoped().Scan(&userEntity) t.AssertNil(err) t.Assert(userEntity.CreatedAt.String(), "2020-11-22 04:23:45") t.Assert(userEntity.UpdatedAt.String(), "2020-11-22 05:23:45") @@ -3129,7 +3131,7 @@ func Test_Model_Fields_Map_Struct(t *testing.T) { XXX_TYPE int } var a = A{} - err := db.Model(table).Fields(a).Where("id", 1).Struct(&a) + err := db.Model(table).Fields(a).Where("id", 1).Scan(&a) t.AssertNil(err) t.Assert(a.ID, 1) t.Assert(a.PASSPORT, "user_1") @@ -3143,7 +3145,7 @@ func Test_Model_Fields_Map_Struct(t *testing.T) { XXX_TYPE int } var a *A - err := db.Model(table).Fields(a).Where("id", 1).Struct(&a) + err := db.Model(table).Fields(a).Where("id", 1).Scan(&a) t.AssertNil(err) t.Assert(a.ID, 1) t.Assert(a.PASSPORT, "user_1") @@ -3157,7 +3159,7 @@ func Test_Model_Fields_Map_Struct(t *testing.T) { XXX_TYPE int } var a *A - err := db.Model(table).Fields(&a).Where("id", 1).Struct(&a) + err := db.Model(table).Fields(&a).Where("id", 1).Scan(&a) t.AssertNil(err) t.Assert(a.ID, 1) t.Assert(a.PASSPORT, "user_1") @@ -3551,3 +3553,179 @@ func Test_Model_Increment_Decrement(t *testing.T) { t.Assert(count, 1) }) } + +func Test_Model_OnDuplicate(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // string. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicate("passport,password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // slice. 
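Save issues an INSERT ... ON DUPLICATE KEY UPDATE, and OnDuplicate / OnDuplicateEx control which columns the UPDATE part touches: OnDuplicate lists the columns to update, OnDuplicateEx lists the ones to leave alone (both also accept slices and maps, as the cases that follow show). A rough sketch, assuming a configured database and a "user" table whose id column is the unique key; the column names mirror the test table and are placeholders.

package main

import (
	"fmt"

	"github.com/gogf/gf/frame/g"
)

func main() {
	db := g.DB() // assumes a valid default MySQL configuration

	data := g.Map{
		"id":       1,
		"passport": "pp1",
		"password": "pw1",
		"nickname": "n1",
	}
	// Only passport and password are updated on duplicate key, so an existing
	// row with id 1 keeps its nickname.
	if _, err := db.Model("user").OnDuplicate("passport,password").Data(data).Save(); err != nil {
		panic(err)
	}
	// The inverse: update everything except nickname and create_time.
	if _, err := db.Model("user").OnDuplicateEx("nickname,create_time").Data(data).Save(); err != nil {
		panic(err)
	}
	one, err := db.Model("user").FindOne(1)
	if err != nil {
		panic(err)
	}
	fmt.Println(one["passport"], one["nickname"])
}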
+ gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicate(g.Slice{"passport", "password"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // map. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicate(g.Map{ + "passport": "nickname", + "password": "nickname", + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["nickname"]) + t.Assert(one["password"], data["nickname"]) + t.Assert(one["nickname"], "name_1") + }) + + // map+raw. + gtest.C(t, func(t *gtest.T) { + data := g.MapStrStr{ + "id": "1", + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicate(g.Map{ + "passport": gdb.Raw("CONCAT(VALUES(`passport`), '1')"), + "password": gdb.Raw("CONCAT(VALUES(`password`), '2')"), + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]+"1") + t.Assert(one["password"], data["password"]+"2") + t.Assert(one["nickname"], "name_1") + }) +} + +func Test_Model_OnDuplicateEx(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // string. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicateEx("nickname,create_time").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicateEx(g.Slice{"nickname", "create_time"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // map. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnDuplicateEx(g.Map{ + "nickname": "nickname", + "create_time": "nickname", + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).FindOne(1) + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) +} + +func Test_Model_Raw(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + all, err := db. + Raw(fmt.Sprintf("select * from %s where id in (?)", table), g.Slice{1, 5, 7, 8, 9, 10}). + WhereLT("id", 8). + WhereIn("id", g.Slice{1, 2, 3, 4, 5, 6, 7}). + OrderDesc("id"). + Limit(2). 
+ All() + t.AssertNil(err) + t.Assert(len(all), 2) + t.Assert(all[0]["id"], 7) + t.Assert(all[1]["id"], 5) + }) + + gtest.C(t, func(t *gtest.T) { + count, err := db. + Raw(fmt.Sprintf("select * from %s where id in (?)", table), g.Slice{1, 5, 7, 8, 9, 10}). + WhereLT("id", 8). + WhereIn("id", g.Slice{1, 2, 3, 4, 5, 6, 7}). + OrderDesc("id"). + Limit(2). + Count() + t.AssertNil(err) + t.Assert(count, 6) + }) +} diff --git a/database/gdb/gdb_z_mysql_raw_test.go b/database/gdb/gdb_z_mysql_raw_type_test.go similarity index 100% rename from database/gdb/gdb_z_mysql_raw_test.go rename to database/gdb/gdb_z_mysql_raw_type_test.go diff --git a/database/gdb/gdb_z_mysql_struct_test.go b/database/gdb/gdb_z_mysql_struct_test.go index dbcf48d97..d4d7a7373 100644 --- a/database/gdb/gdb_z_mysql_struct_test.go +++ b/database/gdb/gdb_z_mysql_struct_test.go @@ -8,10 +8,13 @@ package gdb_test import ( "database/sql" + "github.com/gogf/gf/database/gdb" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/test/gtest" "github.com/gogf/gf/util/gconv" + "reflect" "testing" ) @@ -121,7 +124,7 @@ func Test_Struct_Pointer_Attribute(t *testing.T) { }) gtest.C(t, func(t *gtest.T) { user := new(User) - err := db.Model(table).Struct(user, "id=1") + err := db.Model(table).Scan(user, "id=1") t.AssertNil(err) t.Assert(*user.Id, 1) t.Assert(*user.Passport, "user_1") @@ -130,7 +133,7 @@ func Test_Struct_Pointer_Attribute(t *testing.T) { }) gtest.C(t, func(t *gtest.T) { var user *User - err := db.Model(table).Struct(&user, "id=1") + err := db.Model(table).Scan(&user, "id=1") t.AssertNil(err) t.Assert(*user.Id, 1) t.Assert(*user.Passport, "user_1") @@ -201,7 +204,7 @@ func Test_Structs_Pointer_Attribute(t *testing.T) { // Structs gtest.C(t, func(t *gtest.T) { users := make([]User, 0) - err := db.Model(table).Structs(&users, "id < 3") + err := db.Model(table).Scan(&users, "id < 3") t.AssertNil(err) t.Assert(len(users), 2) t.Assert(*users[0].Id, 1) @@ -211,7 +214,7 @@ func Test_Structs_Pointer_Attribute(t *testing.T) { }) gtest.C(t, func(t *gtest.T) { users := make([]*User, 0) - err := db.Model(table).Structs(&users, "id < 3") + err := db.Model(table).Scan(&users, "id < 3") t.AssertNil(err) t.Assert(len(users), 2) t.Assert(*users[0].Id, 1) @@ -221,7 +224,7 @@ func Test_Structs_Pointer_Attribute(t *testing.T) { }) gtest.C(t, func(t *gtest.T) { var users []User - err := db.Model(table).Structs(&users, "id < 3") + err := db.Model(table).Scan(&users, "id < 3") t.AssertNil(err) t.Assert(len(users), 2) t.Assert(*users[0].Id, 1) @@ -231,7 +234,7 @@ func Test_Structs_Pointer_Attribute(t *testing.T) { }) gtest.C(t, func(t *gtest.T) { var users []*User - err := db.Model(table).Structs(&users, "id < 3") + err := db.Model(table).Scan(&users, "id < 3") t.AssertNil(err) t.Assert(len(users), 2) t.Assert(*users[0].Id, 1) @@ -254,7 +257,7 @@ func Test_Struct_Empty(t *testing.T) { gtest.C(t, func(t *gtest.T) { user := new(User) - err := db.Model(table).Where("id=100").Struct(user) + err := db.Model(table).Where("id=100").Scan(user) t.Assert(err, sql.ErrNoRows) t.AssertNE(user, nil) }) @@ -269,7 +272,7 @@ func Test_Struct_Empty(t *testing.T) { gtest.C(t, func(t *gtest.T) { var user *User - err := db.Model(table).Where("id=100").Struct(&user) + err := db.Model(table).Where("id=100").Scan(&user) t.AssertNil(err) t.Assert(user, nil) }) @@ -395,17 +398,17 @@ type User struct { } func (user *User) UnmarshalValue(value interface{}) error { - switch result := value.(type) { - case 
map[string]interface{}: - user.Id = result["id"].(int) - user.Passport = result["passport"].(string) - user.Password = "" - user.Nickname = result["nickname"].(string) - user.CreateTime = gtime.New(result["create_time"]) + if record, ok := value.(gdb.Record); ok { + *user = User{ + Id: record["id"].Int(), + Passport: record["passport"].String(), + Password: "", + Nickname: record["nickname"].String(), + CreateTime: record["create_time"].GTime(), + } return nil - default: - return gconv.Struct(value, user) } + return gerror.Newf(`unsupported value type for UnmarshalValue: %v`, reflect.TypeOf(value)) } func Test_Model_Scan_UnmarshalValue(t *testing.T) { @@ -452,3 +455,27 @@ func Test_Model_Scan_Map(t *testing.T) { t.Assert(users[9].CreateTime.String(), CreateTime) }) } + +func Test_Scan_AutoFilteringByStructAttributes(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + type User struct { + Id int + Passport string + } + //db.SetDebug(true) + gtest.C(t, func(t *gtest.T) { + var user *User + err := db.Model(table).OrderAsc("id").Scan(&user) + t.AssertNil(err) + t.Assert(user.Id, 1) + }) + gtest.C(t, func(t *gtest.T) { + var users []User + err := db.Model(table).OrderAsc("id").Scan(&users) + t.AssertNil(err) + t.Assert(len(users), TableSize) + t.Assert(users[0].Id, 1) + }) +} diff --git a/database/gdb/gdb_z_mysql_transaction_test.go b/database/gdb/gdb_z_mysql_transaction_test.go index 906bc6aca..669055145 100644 --- a/database/gdb/gdb_z_mysql_transaction_test.go +++ b/database/gdb/gdb_z_mysql_transaction_test.go @@ -163,7 +163,7 @@ func Test_TX_BatchInsert(t *testing.T) { if err != nil { gtest.Error(err) } - if _, err := tx.BatchInsert(table, g.List{ + if _, err := tx.Insert(table, g.List{ { "id": 2, "passport": "t", @@ -201,7 +201,7 @@ func Test_TX_BatchReplace(t *testing.T) { if err != nil { gtest.Error(err) } - if _, err := tx.BatchReplace(table, g.List{ + if _, err := tx.Replace(table, g.List{ { "id": 2, "passport": "USER_2", @@ -244,7 +244,7 @@ func Test_TX_BatchSave(t *testing.T) { if err != nil { gtest.Error(err) } - if _, err := tx.BatchSave(table, g.List{ + if _, err := tx.Save(table, g.List{ { "id": 4, "passport": "USER_4", @@ -349,7 +349,7 @@ func Test_TX_Update(t *testing.T) { if err := tx.Commit(); err != nil { gtest.Error(err) } - _, err = tx.Table(table).Fields("create_time").Where("id", 3).Value() + _, err = tx.Model(table).Fields("create_time").Where("id", 3).Value() t.AssertNE(err, nil) if value, err := db.Model(table).Fields("create_time").Where("id", 3).Value(); err != nil { @@ -666,7 +666,6 @@ func Test_TX_GetScan(t *testing.T) { } func Test_TX_Delete(t *testing.T) { - gtest.C(t, func(t *gtest.T) { table := createInitTable() defer dropTable(table) @@ -685,6 +684,8 @@ func Test_TX_Delete(t *testing.T) { } else { t.Assert(n, 0) } + + t.Assert(tx.IsClosed(), true) }) gtest.C(t, func(t *gtest.T) { @@ -697,7 +698,7 @@ func Test_TX_Delete(t *testing.T) { if _, err := tx.Delete(table, 1); err != nil { gtest.Error(err) } - if n, err := tx.Table(table).Count(); err != nil { + if n, err := tx.Model(table).Count(); err != nil { gtest.Error(err) } else { t.Assert(n, 0) @@ -711,6 +712,8 @@ func Test_TX_Delete(t *testing.T) { t.Assert(n, TableSize) t.AssertNE(n, 0) } + + t.Assert(tx.IsClosed(), true) }) } @@ -721,7 +724,7 @@ func Test_Transaction(t *testing.T) { gtest.C(t, func(t *gtest.T) { ctx := context.TODO() err := db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error { - if _, err := tx.Replace(table, g.Map{ + if _, err := tx.Ctx(ctx).Replace(table, 
g.Map{ "id": 1, "passport": "USER_1", "password": "PASS_1", @@ -730,11 +733,12 @@ func Test_Transaction(t *testing.T) { }); err != nil { t.Error(err) } + t.Assert(tx.IsClosed(), false) return gerror.New("error") }) t.AssertNE(err, nil) - if value, err := db.Model(table).Fields("nickname").Where("id", 1).Value(); err != nil { + if value, err := db.Model(table).Ctx(ctx).Fields("nickname").Where("id", 1).Value(); err != nil { gtest.Error(err) } else { t.Assert(value.String(), "name_1") @@ -956,8 +960,8 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { table := createTable() defer dropTable(table) - db.SetDebug(true) - defer db.SetDebug(false) + //db.SetDebug(true) + //defer db.SetDebug(false) gtest.C(t, func(t *gtest.T) { var ( diff --git a/database/gdb/gdb_z_mysql_types_test.go b/database/gdb/gdb_z_mysql_types_test.go index e26d763b9..7c02df03e 100644 --- a/database/gdb/gdb_z_mysql_types_test.go +++ b/database/gdb/gdb_z_mysql_types_test.go @@ -87,7 +87,7 @@ func Test_Types(t *testing.T) { TinyInt bool } var obj *T - err = db.Model("types").Struct(&obj) + err = db.Model("types").Scan(&obj) t.AssertNil(err) t.Assert(obj.Id, 1) t.Assert(obj.Blob, data["blob"]) diff --git a/database/gdb/gdb_z_mysql_union_test.go b/database/gdb/gdb_z_mysql_union_test.go new file mode 100644 index 000000000..34df871b4 --- /dev/null +++ b/database/gdb/gdb_z_mysql_union_test.go @@ -0,0 +1,146 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb_test + +import ( + "github.com/gogf/gf/frame/g" + "testing" + + "github.com/gogf/gf/test/gtest" +) + +func Test_Union(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + r, err := db.Union( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").All() + + t.AssertNil(err) + + t.Assert(len(r), 3) + t.Assert(r[0]["id"], 3) + t.Assert(r[1]["id"], 2) + t.Assert(r[2]["id"], 1) + }) + + gtest.C(t, func(t *gtest.T) { + r, err := db.Union( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").One() + + t.AssertNil(err) + + t.Assert(r["id"], 3) + }) +} + +func Test_UnionAll(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + r, err := db.UnionAll( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").All() + + t.AssertNil(err) + + t.Assert(len(r), 5) + t.Assert(r[0]["id"], 3) + t.Assert(r[1]["id"], 2) + t.Assert(r[2]["id"], 2) + t.Assert(r[3]["id"], 1) + t.Assert(r[4]["id"], 1) + }) + + gtest.C(t, func(t *gtest.T) { + r, err := db.UnionAll( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").One() + + t.AssertNil(err) + + t.Assert(r["id"], 3) + }) +} + +func Test_Model_Union(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + r, err := db.Model(table).Union( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 
3}).OrderDesc("id"), + ).OrderDesc("id").All() + + t.AssertNil(err) + + t.Assert(len(r), 3) + t.Assert(r[0]["id"], 3) + t.Assert(r[1]["id"], 2) + t.Assert(r[2]["id"], 1) + }) + + gtest.C(t, func(t *gtest.T) { + r, err := db.Model(table).Union( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").One() + + t.AssertNil(err) + + t.Assert(r["id"], 3) + }) +} + +func Test_Model_UnionAll(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + r, err := db.Model(table).UnionAll( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").All() + + t.AssertNil(err) + + t.Assert(len(r), 5) + t.Assert(r[0]["id"], 3) + t.Assert(r[1]["id"], 2) + t.Assert(r[2]["id"], 2) + t.Assert(r[3]["id"], 1) + t.Assert(r[4]["id"], 1) + }) + + gtest.C(t, func(t *gtest.T) { + r, err := db.Model(table).UnionAll( + db.Model(table).Where("id", 1), + db.Model(table).Where("id", 2), + db.Model(table).WhereIn("id", g.Slice{1, 2, 3}).OrderDesc("id"), + ).OrderDesc("id").One() + + t.AssertNil(err) + + t.Assert(r["id"], 3) + }) +} diff --git a/database/gredis/gredis.go b/database/gredis/gredis.go index 13211d81d..8de774e15 100644 --- a/database/gredis/gredis.go +++ b/database/gredis/gredis.go @@ -114,7 +114,7 @@ func New(config *Config) *Redis { if err != nil { return nil, err } - intlog.Printf(`open new connection, config:%+v`, config) + intlog.Printf(context.TODO(), `open new connection, config:%+v`, config) // AUTH if len(config.Pass) > 0 { if _, err := c.Do("AUTH", config.Pass); err != nil { diff --git a/database/gredis/gredis_config.go b/database/gredis/gredis_config.go index c2f2ba42e..121b9c4a1 100644 --- a/database/gredis/gredis_config.go +++ b/database/gredis/gredis_config.go @@ -7,6 +7,7 @@ package gredis import ( + "context" "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" @@ -36,7 +37,7 @@ func SetConfig(config *Config, name ...string) { configs.Set(group, config) instances.Remove(group) - intlog.Printf(`SetConfig for group "%s": %+v`, group, config) + intlog.Printf(context.TODO(), `SetConfig for group "%s": %+v`, group, config) } // SetConfigByStr sets the global configuration for specified group with string. @@ -78,7 +79,7 @@ func RemoveConfig(name ...string) { configs.Remove(group) instances.Remove(group) - intlog.Printf(`RemoveConfig: %s`, group) + intlog.Printf(context.TODO(), `RemoveConfig: %s`, group) } // ConfigFromStr parses and returns config from given str. diff --git a/database/gredis/gredis_conn.go b/database/gredis/gredis_conn.go index f33d4528e..2826affcd 100644 --- a/database/gredis/gredis_conn.go +++ b/database/gredis/gredis_conn.go @@ -8,8 +8,8 @@ package gredis import ( "context" - "errors" "github.com/gogf/gf/container/gvar" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/util/gconv" @@ -50,7 +50,7 @@ func (c *Conn) do(timeout time.Duration, commandName string, args ...interface{} if timeout > 0 { conn, ok := c.Conn.(redis.ConnWithTimeout) if !ok { - return gvar.New(nil), errors.New(`current connection does not support "ConnWithTimeout"`) + return gvar.New(nil), gerror.New(`current connection does not support "ConnWithTimeout"`) } return conn.DoWithTimeout(timeout, commandName, args...) 
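The Union/UnionAll builders exercised by the tests above compose several models into one statement: UNION de-duplicates rows, UNION ALL keeps every match, and the outer builder can still apply OrderDesc, Limit, All or One. A small sketch under the assumption of a "user" table holding one row per id 1..3:

package main

import (
	"fmt"

	"github.com/gogf/gf/frame/g"
)

func main() {
	db := g.DB() // assumes a valid default MySQL configuration

	// UNION removes duplicates across the combined selects.
	all, err := db.Union(
		db.Model("user").Where("id", 1),
		db.Model("user").Where("id", 2),
		db.Model("user").WhereIn("id", g.Slice{1, 2, 3}),
	).OrderDesc("id").All()
	if err != nil {
		panic(err)
	}
	fmt.Println(len(all)) // 3 distinct rows: ids 3, 2, 1

	// UNION ALL keeps every row, so id 1 (matched by both selects) shows up twice.
	dup, err := db.Model("user").UnionAll(
		db.Model("user").Where("id", 1),
		db.Model("user").WhereIn("id", g.Slice{1, 2}),
	).OrderDesc("id").All()
	if err != nil {
		panic(err)
	}
	fmt.Println(len(dup)) // 3 rows: id 2 once, id 1 twice
}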
} @@ -107,7 +107,7 @@ func (c *Conn) ReceiveVar() (*gvar.Var, error) { func (c *Conn) ReceiveVarWithTimeout(timeout time.Duration) (*gvar.Var, error) { conn, ok := c.Conn.(redis.ConnWithTimeout) if !ok { - return gvar.New(nil), errors.New(`current connection does not support "ConnWithTimeout"`) + return gvar.New(nil), gerror.New(`current connection does not support "ConnWithTimeout"`) } return resultToVar(conn.ReceiveWithTimeout(timeout)) } diff --git a/database/gredis/gredis_conn_tracing.go b/database/gredis/gredis_conn_tracing.go index 9ee411323..4fc478423 100644 --- a/database/gredis/gredis_conn_tracing.go +++ b/database/gredis/gredis_conn_tracing.go @@ -12,7 +12,6 @@ import ( "github.com/gogf/gf" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/net/gtrace" - "github.com/gogf/gf/os/gcmd" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -38,19 +37,9 @@ const ( tracingEventRedisExecutionArguments = "redis.execution.arguments" ) -var ( - // tracingInternal enables tracing for internal type spans. - // It's true in default. - tracingInternal = true -) - -func init() { - tracingInternal = gcmd.GetOptWithEnv("gf.tracing.internal", true).Bool() -} - // addTracingItem checks and adds redis tracing information to OpenTelemetry. func (c *Conn) addTracingItem(item *tracingItem) { - if !tracingInternal || !gtrace.IsActivated(c.ctx) { + if !gtrace.IsTracingInternal() || !gtrace.IsActivated(c.ctx) { return } tr := otel.GetTracerProvider().Tracer( diff --git a/debug/gdebug/gdebug_caller.go b/debug/gdebug/gdebug_caller.go index 485a61228..fa9bb2115 100644 --- a/debug/gdebug/gdebug_caller.go +++ b/debug/gdebug/gdebug_caller.go @@ -8,6 +8,7 @@ package gdebug import ( "fmt" + "github.com/gogf/gf/internal/utils" "os" "os/exec" "path/filepath" @@ -53,11 +54,13 @@ func Caller(skip ...int) (function string, path string, line int) { // // The parameter <filter> is used to filter the path of the caller. 
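The gerror constructors replacing errors.New and fmt.Errorf in these files do more than format a message: they record a stack and can be wrapped and queried. A brief sketch of the calls involved (Newf, plus Wrap/Cause/Code for context); the configuration file name is a made-up placeholder.

package main

import (
	"fmt"

	"github.com/gogf/gf/errors/gerror"
)

func openConfig(path string) error {
	// Formatted construction in one call, no fmt.Sprintf needed.
	return gerror.Newf(`configuration file "%s" does not exist`, path)
}

func main() {
	err := openConfig("config.toml")

	// Wrap adds context while keeping the original error as the cause.
	wrapped := gerror.Wrap(err, "loading application settings failed")

	fmt.Println(wrapped)               // messages joined as "context: cause"
	fmt.Println(gerror.Cause(wrapped)) // the original error
	fmt.Println(gerror.Code(wrapped))  // -1: no error code attached
}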
func CallerWithFilter(filter string, skip ...int) (function string, path string, line int) { - number := 0 + var ( + number = 0 + ok = true + ) if len(skip) > 0 { number = skip[0] } - ok := true pc, file, line, start := callerFromIndex([]string{filter}) if start != -1 { for i := start + number; i < maxCallerDepth; i++ { @@ -65,12 +68,6 @@ func CallerWithFilter(filter string, skip ...int) (function string, path string, pc, file, line, ok = runtime.Caller(i) } if ok { - if filter != "" && strings.Contains(file, filter) { - continue - } - if strings.Contains(file, stackFilterKey) { - continue - } function := "" if fn := runtime.FuncForPC(pc); fn == nil { function = "unknown" @@ -104,8 +101,14 @@ func callerFromIndex(filters []string) (pc uintptr, file string, line int, index if filtered { continue } - if strings.Contains(file, stackFilterKey) { - continue + if !utils.IsDebugEnabled() { + if strings.Contains(file, utils.StackFilterKeyForGoFrame) { + continue + } + } else { + if strings.Contains(file, stackFilterKey) { + continue + } } if index > 0 { index-- diff --git a/debug/gdebug/gdebug_stack.go b/debug/gdebug/gdebug_stack.go index 09ac951cd..a9c0a1e9c 100644 --- a/debug/gdebug/gdebug_stack.go +++ b/debug/gdebug/gdebug_stack.go @@ -80,14 +80,17 @@ func StackWithFilters(filters []string, skip ...int) string { if filtered { continue } - if strings.Contains(file, stackFilterKey) { - continue - } + if !utils.IsDebugEnabled() { if strings.Contains(file, utils.StackFilterKeyForGoFrame) { continue } + } else { + if strings.Contains(file, stackFilterKey) { + continue + } } + if fn := runtime.FuncForPC(pc); fn == nil { name = "unknown" } else { diff --git a/encoding/gcharset/gcharset.go b/encoding/gcharset/gcharset.go index c976fa606..0aac86b27 100644 --- a/encoding/gcharset/gcharset.go +++ b/encoding/gcharset/gcharset.go @@ -21,8 +21,7 @@ package gcharset import ( "bytes" - "errors" - "fmt" + "github.com/gogf/gf/errors/gerror" "io/ioutil" "golang.org/x/text/encoding" @@ -60,11 +59,11 @@ func Convert(dstCharset string, srcCharset string, src string) (dst string, err transform.NewReader(bytes.NewReader([]byte(src)), e.NewDecoder()), ) if err != nil { - return "", fmt.Errorf("%s to utf8 failed. %v", srcCharset, err) + return "", gerror.Newf("%s to utf8 failed. %v", srcCharset, err) } src = string(tmp) } else { - return dst, errors.New(fmt.Sprintf("unsupport srcCharset: %s", srcCharset)) + return dst, gerror.Newf("unsupport srcCharset: %s", srcCharset) } } // Do the converting from UTF-8 to <dstCharset>. @@ -74,11 +73,11 @@ func Convert(dstCharset string, srcCharset string, src string) (dst string, err transform.NewReader(bytes.NewReader([]byte(src)), e.NewEncoder()), ) if err != nil { - return "", fmt.Errorf("utf to %s failed. %v", dstCharset, err) + return "", gerror.Newf("utf to %s failed. 
%v", dstCharset, err) } dst = string(tmp) } else { - return dst, errors.New(fmt.Sprintf("unsupport dstCharset: %s", dstCharset)) + return dst, gerror.Newf("unsupport dstCharset: %s", dstCharset) } } else { dst = src diff --git a/encoding/gcompress/gcompress_zip.go b/encoding/gcompress/gcompress_zip.go index 72991aff9..dad27820b 100644 --- a/encoding/gcompress/gcompress_zip.go +++ b/encoding/gcompress/gcompress_zip.go @@ -9,6 +9,7 @@ package gcompress import ( "archive/zip" "bytes" + "context" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" "github.com/gogf/gf/text/gstr" @@ -92,7 +93,7 @@ func doZipPathWriter(path string, exclude string, zipWriter *zip.Writer, prefix headerPrefix = strings.Replace(headerPrefix, "//", "/", -1) for _, file := range files { if exclude == file { - intlog.Printf(`exclude file path: %s`, file) + intlog.Printf(context.TODO(), `exclude file path: %s`, file) continue } dir := gfile.Dir(file[len(path):]) diff --git a/encoding/gini/gini.go b/encoding/gini/gini.go index 8ca3770bb..6394d2393 100644 --- a/encoding/gini/gini.go +++ b/encoding/gini/gini.go @@ -10,8 +10,8 @@ package gini import ( "bufio" "bytes" - "errors" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/json" "io" "strings" @@ -70,7 +70,7 @@ func Decode(data []byte) (res map[string]interface{}, err error) { } if haveSection == false { - return nil, errors.New("failed to parse INI file, section not found") + return nil, gerror.New("failed to parse INI file, section not found") } return res, nil } diff --git a/encoding/gjson/gjson_api.go b/encoding/gjson/gjson_api.go index d065910f2..a45f6a6de 100644 --- a/encoding/gjson/gjson_api.go +++ b/encoding/gjson/gjson_api.go @@ -319,23 +319,11 @@ func (j *Json) GetStruct(pattern string, pointer interface{}, mapping ...map[str return gconv.Struct(j.Get(pattern), pointer, mapping...) } -// GetStructDeep does GetStruct recursively. -// Deprecated, use GetStruct instead. -func (j *Json) GetStructDeep(pattern string, pointer interface{}, mapping ...map[string]string) error { - return gconv.StructDeep(j.Get(pattern), pointer, mapping...) -} - // GetStructs converts any slice to given struct slice. func (j *Json) GetStructs(pattern string, pointer interface{}, mapping ...map[string]string) error { return gconv.Structs(j.Get(pattern), pointer, mapping...) } -// GetStructsDeep converts any slice to given struct slice recursively. -// Deprecated, use GetStructs instead. -func (j *Json) GetStructsDeep(pattern string, pointer interface{}, mapping ...map[string]string) error { - return gconv.StructsDeep(j.Get(pattern), pointer, mapping...) -} - // GetScan automatically calls Struct or Structs function according to the type of parameter // <pointer> to implement the converting.. func (j *Json) GetScan(pattern string, pointer interface{}, mapping ...map[string]string) error { diff --git a/encoding/gjson/gjson_api_encoding.go b/encoding/gjson/gjson_api_encoding.go index d745c7119..a0eabdb90 100644 --- a/encoding/gjson/gjson_api_encoding.go +++ b/encoding/gjson/gjson_api_encoding.go @@ -70,7 +70,7 @@ func (j *Json) MustToJsonIndentString() string { // ======================================================================== func (j *Json) ToXml(rootTag ...string) ([]byte, error) { - return gxml.Encode(j.ToMap(), rootTag...) + return gxml.Encode(j.Map(), rootTag...) 
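For the charset conversion above, a short usage sketch: Convert(dst, src, content) goes through UTF-8 internally, so either side may be UTF-8 or any encoding the package supports, and an unknown name is reported as an error rather than guessed. The sample string is arbitrary.

package main

import (
	"fmt"

	"github.com/gogf/gf/encoding/gcharset"
)

func main() {
	src := "GoFrame 框架"

	// UTF-8 -> GBK and back again.
	gbk, err := gcharset.Convert("GBK", "UTF-8", src)
	if err != nil {
		panic(err)
	}
	back, err := gcharset.Convert("UTF-8", "GBK", gbk)
	if err != nil {
		panic(err)
	}
	fmt.Println(back == src) // true

	// Unsupported charsets surface as errors.
	if _, err = gcharset.Convert("NOT-A-CHARSET", "UTF-8", src); err != nil {
		fmt.Println(err) // unsupport dstCharset: NOT-A-CHARSET
	}
}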
} func (j *Json) ToXmlString(rootTag ...string) (string, error) { @@ -79,7 +79,7 @@ func (j *Json) ToXmlString(rootTag ...string) (string, error) { } func (j *Json) ToXmlIndent(rootTag ...string) ([]byte, error) { - return gxml.EncodeWithIndent(j.ToMap(), rootTag...) + return gxml.EncodeWithIndent(j.Map(), rootTag...) } func (j *Json) ToXmlIndentString(rootTag ...string) (string, error) { diff --git a/encoding/gjson/gjson_api_new_load.go b/encoding/gjson/gjson_api_new_load.go index bd6dcb504..780b3a16f 100644 --- a/encoding/gjson/gjson_api_new_load.go +++ b/encoding/gjson/gjson_api_new_load.go @@ -8,8 +8,8 @@ package gjson import ( "bytes" - "errors" "fmt" + "github.com/gogf/gf/errors/gerror" "reflect" "github.com/gogf/gf/internal/json" @@ -264,7 +264,7 @@ func doLoadContentWithOptions(dataType string, data []byte, options Options) (*J return nil, err } default: - err = errors.New("unsupported type for loading") + err = gerror.New("unsupported type for loading") } if err != nil { return nil, err diff --git a/encoding/gjson/gjson_deprecated.go b/encoding/gjson/gjson_deprecated.go deleted file mode 100644 index 9283866ca..000000000 --- a/encoding/gjson/gjson_deprecated.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. -// -// This Source Code Form is subject to the terms of the MIT License. -// If a copy of the MIT was not distributed with this file, -// You can obtain one at https://github.com/gogf/gf. - -package gjson - -import "github.com/gogf/gf/util/gconv" - -// ToMap converts current Json object to map[string]interface{}. -// It returns nil if fails. -// Deprecated, use Map instead. -func (j *Json) ToMap() map[string]interface{} { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.Map(*(j.p)) -} - -// ToArray converts current Json object to []interface{}. -// It returns nil if fails. -// Deprecated, use Array instead. -func (j *Json) ToArray() []interface{} { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.Interfaces(*(j.p)) -} - -// ToStruct converts current Json object to specified object. -// The <pointer> should be a pointer type of *struct. -// Deprecated, use Struct instead. -func (j *Json) ToStruct(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.Struct(*(j.p), pointer, mapping...) -} - -// ToStructDeep converts current Json object to specified object recursively. -// The <pointer> should be a pointer type of *struct. -// Deprecated, use Struct instead. -func (j *Json) ToStructDeep(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.StructDeep(*(j.p), pointer, mapping...) -} - -// ToStructs converts current Json object to specified object slice. -// The <pointer> should be a pointer type of []struct/*struct. -// Deprecated, use Structs instead. -func (j *Json) ToStructs(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.Structs(*(j.p), pointer, mapping...) -} - -// ToStructsDeep converts current Json object to specified object slice recursively. -// The <pointer> should be a pointer type of []struct/*struct. -// Deprecated, use Structs instead. -func (j *Json) ToStructsDeep(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.StructsDeep(*(j.p), pointer, mapping...) 
-} - -// ToScan automatically calls Struct or Structs function according to the type of parameter -// <pointer> to implement the converting.. -// Deprecated, use Scan instead. -func (j *Json) ToScan(pointer interface{}, mapping ...map[string]string) error { - return gconv.Scan(*(j.p), pointer, mapping...) -} - -// ToScanDeep automatically calls StructDeep or StructsDeep function according to the type of -// parameter <pointer> to implement the converting.. -// Deprecated, use Scan instead. -func (j *Json) ToScanDeep(pointer interface{}, mapping ...map[string]string) error { - return gconv.ScanDeep(*(j.p), pointer, mapping...) -} - -// ToMapToMap converts current Json object to specified map variable. -// The parameter of <pointer> should be type of *map. -// Deprecated, use MapToMap instead. -func (j *Json) ToMapToMap(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.MapToMap(*(j.p), pointer, mapping...) -} - -// ToMapToMaps converts current Json object to specified map variable slice. -// The parameter of <pointer> should be type of []map/*map. -// Deprecated, use MapToMaps instead. -func (j *Json) ToMapToMaps(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.MapToMaps(*(j.p), pointer, mapping...) -} - -// ToMapToMapsDeep converts current Json object to specified map variable slice recursively. -// The parameter of <pointer> should be type of []map/*map. -// Deprecated, use MapToMaps instead. -func (j *Json) ToMapToMapsDeep(pointer interface{}, mapping ...map[string]string) error { - j.mu.RLock() - defer j.mu.RUnlock() - return gconv.MapToMapsDeep(*(j.p), pointer, mapping...) -} diff --git a/encoding/gjson/gjson_z_example_conversion_test.go b/encoding/gjson/gjson_z_example_conversion_test.go index 1a103f6a6..c55d9c409 100644 --- a/encoding/gjson/gjson_z_example_conversion_test.go +++ b/encoding/gjson/gjson_z_example_conversion_test.go @@ -100,7 +100,7 @@ func Example_conversionToStruct() { Array []string } users := new(Users) - if err := j.ToStruct(users); err != nil { + if err := j.Struct(users); err != nil { panic(err) } fmt.Printf(`%+v`, users) diff --git a/encoding/gjson/gjson_z_unit_basic_test.go b/encoding/gjson/gjson_z_unit_basic_test.go index c89b615ad..35b9538f0 100644 --- a/encoding/gjson/gjson_z_unit_basic_test.go +++ b/encoding/gjson/gjson_z_unit_basic_test.go @@ -353,7 +353,7 @@ func Test_Convert2(t *testing.T) { t.Assert(j.GetGTime("time").Format("Y-m-d"), "2019-06-12") t.Assert(j.GetDuration("time").String(), "0s") - err := j.ToStruct(&name) + err := j.Struct(&name) t.Assert(err, nil) t.Assert(name.Name, "gf") //j.Dump() @@ -369,7 +369,7 @@ func Test_Convert2(t *testing.T) { t.Assert(err, nil) j = gjson.New(`[1,2,3]`) - t.Assert(len(j.ToArray()), 3) + t.Assert(len(j.Array()), 3) }) } @@ -400,7 +400,7 @@ func Test_Basic(t *testing.T) { err = j.Remove("1") t.Assert(err, nil) t.Assert(j.Get("0"), 1) - t.Assert(len(j.ToArray()), 2) + t.Assert(len(j.Array()), 2) j = gjson.New(`[1,2,3]`) // If index 0 is delete, its next item will be at index 0. 
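With the To* conversion methods gone, the shorter Map, Array, Struct and Scan calls cover the same ground. A compact sketch, assuming nothing beyond the gjson package itself; the JSON document and the User type are illustrative.

package main

import (
	"fmt"

	"github.com/gogf/gf/encoding/gjson"
)

type User struct {
	Name  string
	Score int
}

func main() {
	j, err := gjson.LoadContent(`{"name":"john","score":100,"tags":["a","b"]}`)
	if err != nil {
		panic(err)
	}

	// Map and Array read the whole document as generic containers.
	fmt.Println(j.Map()["name"])                   // john
	fmt.Println(len(gjson.New(`[1,2,3]`).Array())) // 3

	// Struct fills a typed value; Scan additionally handles slice targets.
	var u User
	if err := j.Struct(&u); err != nil {
		panic(err)
	}
	fmt.Println(u.Name, u.Score) // john 100
}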
@@ -408,13 +408,13 @@ func Test_Basic(t *testing.T) { t.Assert(j.Remove("0"), nil) t.Assert(j.Remove("0"), nil) t.Assert(j.Get("0"), nil) - t.Assert(len(j.ToArray()), 0) + t.Assert(len(j.Array()), 0) j = gjson.New(`[1,2,3]`) err = j.Remove("3") t.Assert(err, nil) t.Assert(j.Get("0"), 1) - t.Assert(len(j.ToArray()), 3) + t.Assert(len(j.Array()), 3) j = gjson.New(`[1,2,3]`) err = j.Remove("0.3") diff --git a/encoding/gjson/gjson_z_unit_json_test.go b/encoding/gjson/gjson_z_unit_json_test.go index 8c2b50feb..27f5bd27c 100644 --- a/encoding/gjson/gjson_z_unit_json_test.go +++ b/encoding/gjson/gjson_z_unit_json_test.go @@ -61,7 +61,7 @@ func Test_MapAttributeConvert(t *testing.T) { Title map[string]interface{} }{} - err = j.ToStruct(&tx) + err = j.Struct(&tx) gtest.Assert(err, nil) t.Assert(tx.Title, g.Map{ "l1": "标签1", "l2": "标签2", @@ -76,7 +76,7 @@ func Test_MapAttributeConvert(t *testing.T) { Title map[string]string }{} - err = j.ToStruct(&tx) + err = j.Struct(&tx) gtest.Assert(err, nil) t.Assert(tx.Title, g.Map{ "l1": "标签1", "l2": "标签2", diff --git a/encoding/gparser/gparser_unit_basic_test.go b/encoding/gparser/gparser_unit_basic_test.go index 59beb1d69..bcdae3700 100644 --- a/encoding/gparser/gparser_unit_basic_test.go +++ b/encoding/gparser/gparser_unit_basic_test.go @@ -230,14 +230,14 @@ func Test_Convert(t *testing.T) { err := p.GetStruct("person", &name) t.Assert(err, nil) t.Assert(name.Name, "gf") - t.Assert(p.ToMap()["name"], "gf") - err = p.ToStruct(&name) + t.Assert(p.Map()["name"], "gf") + err = p.Struct(&name) t.Assert(err, nil) t.Assert(name.Name, "gf") //p.Dump() p = gparser.New(`[0,1,2]`) - t.Assert(p.ToArray()[0], 0) + t.Assert(p.Array()[0], 0) }) } diff --git a/errors/gerror/gerror.go b/errors/gerror/gerror.go index ce5bc0fc5..8b029b23f 100644 --- a/errors/gerror/gerror.go +++ b/errors/gerror/gerror.go @@ -4,7 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package errors provides simple functions to manipulate errors. +// Package gerror provides simple functions to manipulate errors. // // Very note that, this package is quite a base package, which should not import extra // packages except standard packages, to avoid cycle imports. @@ -237,7 +237,7 @@ func WrapCodeSkipf(code, skip int, err error, format string, args ...interface{} } } -// Cause returns the error code of current error. +// Code returns the error code of current error. // It returns -1 if it has no error code or it does not implements interface Code. 
func Code(err error) int { if err != nil { diff --git a/frame/gins/gins_database.go b/frame/gins/gins_database.go index 26e97b5fe..1206ee7d4 100644 --- a/frame/gins/gins_database.go +++ b/frame/gins/gins_database.go @@ -7,6 +7,7 @@ package gins import ( + "context" "fmt" "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" @@ -92,11 +93,11 @@ func Database(name ...string) gdb.DB { } if len(cg) > 0 { if gdb.GetConfig(group) == nil { - intlog.Printf("add configuration for group: %s, %#v", g, cg) + intlog.Printf(context.TODO(), "add configuration for group: %s, %#v", g, cg) gdb.SetConfigGroup(g, cg) } else { - intlog.Printf("ignore configuration as it already exists for group: %s, %#v", g, cg) - intlog.Printf("%s, %#v", g, cg) + intlog.Printf(context.TODO(), "ignore configuration as it already exists for group: %s, %#v", g, cg) + intlog.Printf(context.TODO(), "%s, %#v", g, cg) } } } @@ -104,17 +105,17 @@ func Database(name ...string) gdb.DB { // which is the default group configuration. if node := parseDBConfigNode(configMap); node != nil { cg := gdb.ConfigGroup{} - if node.LinkInfo != "" || node.Host != "" { + if node.Link != "" || node.Host != "" { cg = append(cg, *node) } if len(cg) > 0 { if gdb.GetConfig(group) == nil { - intlog.Printf("add configuration for group: %s, %#v", gdb.DefaultGroupName, cg) + intlog.Printf(context.TODO(), "add configuration for group: %s, %#v", gdb.DefaultGroupName, cg) gdb.SetConfigGroup(gdb.DefaultGroupName, cg) } else { - intlog.Printf("ignore configuration as it already exists for group: %s, %#v", gdb.DefaultGroupName, cg) - intlog.Printf("%s, %#v", gdb.DefaultGroupName, cg) + intlog.Printf(context.TODO(), "ignore configuration as it already exists for group: %s, %#v", gdb.DefaultGroupName, cg) + intlog.Printf(context.TODO(), "%s, %#v", gdb.DefaultGroupName, cg) } } } @@ -156,15 +157,19 @@ func parseDBConfigNode(value interface{}) *gdb.ConfigNode { if err != nil { panic(err) } - if _, v := gutil.MapPossibleItemByKey(nodeMap, "link"); v != nil { - node.LinkInfo = gconv.String(v) + // To be compatible with old version. + if _, v := gutil.MapPossibleItemByKey(nodeMap, "LinkInfo"); v != nil { + node.Link = gconv.String(v) + } + if _, v := gutil.MapPossibleItemByKey(nodeMap, "Link"); v != nil { + node.Link = gconv.String(v) } // Parse link syntax. - if node.LinkInfo != "" && node.Type == "" { - match, _ := gregex.MatchString(`([a-z]+):(.+)`, node.LinkInfo) + if node.Link != "" && node.Type == "" { + match, _ := gregex.MatchString(`([a-z]+):(.+)`, node.Link) if len(match) == 3 { node.Type = gstr.Trim(match[1]) - node.LinkInfo = gstr.Trim(match[2]) + node.Link = gstr.Trim(match[2]) } } return node diff --git a/frame/gins/gins_server.go b/frame/gins/gins_server.go index 090835628..e73eadaa8 100644 --- a/frame/gins/gins_server.go +++ b/frame/gins/gins_server.go @@ -24,17 +24,30 @@ func Server(name ...interface{}) *ghttp.Server { s := ghttp.GetServer(name...) // To avoid file no found error while it's not necessary. if Config().Available() { - var m map[string]interface{} + var ( + serverConfigMap map[string]interface{} + serverLoggerConfigMap map[string]interface{} + ) nodeKey, _ := gutil.MapPossibleItemByKey(Config().GetMap("."), configNodeNameServer) if nodeKey == "" { nodeKey = configNodeNameServer } - m = Config().GetMap(fmt.Sprintf(`%s.%s`, nodeKey, s.GetName())) - if len(m) == 0 { - m = Config().GetMap(nodeKey) + // Server configuration. 
+ serverConfigMap = Config().GetMap(fmt.Sprintf(`%s.%s`, nodeKey, s.GetName())) + if len(serverConfigMap) == 0 { + serverConfigMap = Config().GetMap(nodeKey) } - if len(m) > 0 { - if err := s.SetConfigWithMap(m); err != nil { + if len(serverConfigMap) > 0 { + if err := s.SetConfigWithMap(serverConfigMap); err != nil { + panic(err) + } + } + // Server logger configuration. + serverLoggerConfigMap = Config().GetMap( + fmt.Sprintf(`%s.%s.%s`, nodeKey, s.GetName(), configNodeNameLogger), + ) + if len(serverLoggerConfigMap) > 0 { + if err := s.Logger().SetConfigWithMap(serverLoggerConfigMap); err != nil { panic(err) } } diff --git a/go.mod b/go.mod index ba8172629..0ccef1c24 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.14 require ( github.com/BurntSushi/toml v0.3.1 github.com/clbanning/mxj v1.8.5-0.20200714211355-ff02cfb8ea28 - github.com/fatih/color v1.12.0 // indirect + github.com/fatih/color v1.12.0 github.com/fsnotify/fsnotify v1.4.9 github.com/gogf/mysql v1.6.1-0.20210603073548-16164ae25579 github.com/gomodule/redigo v2.0.0+incompatible diff --git a/go.sum b/go.sum index 582d62e0e..bd012e8c6 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/clbanning/mxj v1.8.5-0.20200714211355-ff02cfb8ea28 h1:LdXxtjzvZYhhUao github.com/clbanning/mxj v1.8.5-0.20200714211355-ff02cfb8ea28/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.12.0 h1:mRhaKNwANqRgUBGKmnI5ZxEk7QXmjQeCcuYFMX2bfcc= +github.com/fatih/color v1.12.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/gogf/mysql v1.6.1-0.20210603073548-16164ae25579 h1:pP/uEy52biKDytlgK/ug8kiYPAiYu6KajKVUHfGrtyw= @@ -16,6 +18,10 @@ github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvK github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grokify/html-strip-tags-go v0.0.0-20190921062105-daaa06bf1aaf h1:wIOAyJMMen0ELGiFzlmqxdcV1yGbkyHBAB6PolcNbLA= github.com/grokify/html-strip-tags-go v0.0.0-20190921062105-daaa06bf1aaf/go.mod h1:2Su6romC5/1VXOQMaWL2yb618ARB8iVo6/DR99A6d78= +github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= +github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.10 h1:CoZ3S2P7pvtP45xOtBw+/mDL2z0RKI576gSkzRRpdGg= github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= @@ -44,6 +50,8 @@ golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/i18n/gi18n/gi18n_manager.go b/i18n/gi18n/gi18n_manager.go index 0949c1f44..cc223e32d 100644 --- a/i18n/gi18n/gi18n_manager.go +++ b/i18n/gi18n/gi18n_manager.go @@ -8,8 +8,8 @@ package gi18n import ( "context" - "errors" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "strings" "sync" @@ -70,7 +70,7 @@ func New(options ...Options) *Manager { gregex.Quote(opts.Delimiters[1]), ), } - intlog.Printf(`New: %#v`, m) + intlog.Printf(context.TODO(), `New: %#v`, m) return m } @@ -101,24 +101,24 @@ func (m *Manager) SetPath(path string) error { } else { realPath, _ := gfile.Search(path) if realPath == "" { - return errors.New(fmt.Sprintf(`%s does not exist`, path)) + return gerror.Newf(`%s does not exist`, path) } m.options.Path = realPath } - intlog.Printf(`SetPath: %s`, m.options.Path) + intlog.Printf(context.TODO(), `SetPath: %s`, m.options.Path) return nil } // SetLanguage sets the language for translator. func (m *Manager) SetLanguage(language string) { m.options.Language = language - intlog.Printf(`SetLanguage: %s`, m.options.Language) + intlog.Printf(context.TODO(), `SetLanguage: %s`, m.options.Language) } // SetDelimiters sets the delimiters for translator. func (m *Manager) SetDelimiters(left, right string) { m.pattern = fmt.Sprintf(`%s(\w+)%s`, gregex.Quote(left), gregex.Quote(right)) - intlog.Printf(`SetDelimiters: %v`, m.pattern) + intlog.Printf(context.TODO(), `SetDelimiters: %v`, m.pattern) } // T is alias of Translate for convenience. @@ -139,7 +139,7 @@ func (m *Manager) TranslateFormat(ctx context.Context, format string, values ... // Translate translates <content> with configured language. func (m *Manager) Translate(ctx context.Context, content string) string { - m.init() + m.init(ctx) m.mu.RLock() defer m.mu.RUnlock() transLang := m.options.Language @@ -163,14 +163,14 @@ func (m *Manager) Translate(ctx context.Context, content string) string { } return match[0] }) - intlog.Printf(`Translate for language: %s`, transLang) + intlog.Printf(ctx, `Translate for language: %s`, transLang) return result } // GetContent retrieves and returns the configured content for given key and specified language. // It returns an empty string if not found. func (m *Manager) GetContent(ctx context.Context, key string) string { - m.init() + m.init(ctx) m.mu.RLock() defer m.mu.RUnlock() transLang := m.options.Language @@ -185,7 +185,7 @@ func (m *Manager) GetContent(ctx context.Context, key string) string { // init initializes the manager for lazy initialization design. // The i18n manager is only initialized once. -func (m *Manager) init() { +func (m *Manager) init(ctx context.Context) { m.mu.RLock() // If the data is not nil, means it's already initialized. 
if m.data != nil { @@ -223,17 +223,13 @@ func (m *Manager) init() { m.data[lang][k] = gconv.String(v) } } else { - intlog.Errorf("load i18n file '%s' failed: %v", name, err) + intlog.Errorf(ctx, "load i18n file '%s' failed: %v", name, err) } } } } else if m.options.Path != "" { files, _ := gfile.ScanDirFile(m.options.Path, "*.*", true) if len(files) == 0 { - //intlog.Printf( - // "no i18n files found in configured directory: %s", - // m.options.Path, - //) return } var ( @@ -258,7 +254,7 @@ func (m *Manager) init() { m.data[lang][k] = gconv.String(v) } } else { - intlog.Errorf("load i18n file '%s' failed: %v", file, err) + intlog.Errorf(ctx, "load i18n file '%s' failed: %v", file, err) } } // Monitor changes of i18n files for hot reload feature. diff --git a/internal/intlog/intlog.go b/internal/intlog/intlog.go index f2675ce91..d5ca85d6a 100644 --- a/internal/intlog/intlog.go +++ b/internal/intlog/intlog.go @@ -8,9 +8,12 @@ package intlog import ( + "bytes" + "context" "fmt" "github.com/gogf/gf/debug/gdebug" "github.com/gogf/gf/internal/utils" + "go.opentelemetry.io/otel/trace" "path/filepath" "time" ) @@ -39,42 +42,56 @@ func SetEnabled(enabled bool) { // Print prints `v` with newline using fmt.Println. // The parameter `v` can be multiple variables. -func Print(v ...interface{}) { - if !isGFDebug { - return - } - fmt.Println(append([]interface{}{now(), "[INTE]", file()}, v...)...) +func Print(ctx context.Context, v ...interface{}) { + doPrint(ctx, fmt.Sprint(v...), false) } // Printf prints `v` with format `format` using fmt.Printf. // The parameter `v` can be multiple variables. -func Printf(format string, v ...interface{}) { - if !isGFDebug { - return - } - fmt.Printf(now()+" [INTE] "+file()+" "+format+"\n", v...) +func Printf(ctx context.Context, format string, v ...interface{}) { + doPrint(ctx, fmt.Sprintf(format, v...), false) } // Error prints `v` with newline using fmt.Println. // The parameter `v` can be multiple variables. -func Error(v ...interface{}) { - if !isGFDebug { - return - } - array := append([]interface{}{now(), "[INTE]", file()}, v...) - array = append(array, "\n"+gdebug.StackWithFilter(stackFilterKey)) - fmt.Println(array...) +func Error(ctx context.Context, v ...interface{}) { + doPrint(ctx, fmt.Sprint(v...), true) } // Errorf prints `v` with format `format` using fmt.Printf. -func Errorf(format string, v ...interface{}) { +func Errorf(ctx context.Context, format string, v ...interface{}) { + doPrint(ctx, fmt.Sprintf(format, v...), true) +} + +func doPrint(ctx context.Context, content string, stack bool) { if !isGFDebug { return } - fmt.Printf( - now()+" [INTE] "+file()+" "+format+"\n%s\n", - append(v, gdebug.StackWithFilter(stackFilterKey))..., - ) + buffer := bytes.NewBuffer(nil) + buffer.WriteString(now()) + buffer.WriteString(" [INTE] ") + buffer.WriteString(file()) + if s := traceIdStr(ctx); s != "" { + buffer.WriteString(" " + s) + } + buffer.WriteString(content) + buffer.WriteString("\n") + if stack { + buffer.WriteString(gdebug.StackWithFilter(stackFilterKey)) + } + fmt.Print(buffer.String()) +} + +// traceIdStr retrieves and returns the trace id string for logging output. +func traceIdStr(ctx context.Context) string { + if ctx == nil { + return "" + } + spanCtx := trace.SpanContextFromContext(ctx) + if traceId := spanCtx.TraceID(); traceId.IsValid() { + return "{" + traceId.String() + "}" + } + return "" } // now returns current time string. 
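Passing a context.Context through the logging calls makes the trace id available for correlation, exactly as traceIdStr above extracts it. Since intlog is an internal package, the sketch below reproduces the pattern with the public OpenTelemetry API; logf and traceIdPrefix are illustrative names, not part of the framework.

package main

import (
	"context"
	"fmt"

	"go.opentelemetry.io/otel/trace"
)

// traceIdPrefix returns "{<trace-id>} " when the context carries a valid
// OpenTelemetry span context, and an empty string otherwise.
func traceIdPrefix(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	spanCtx := trace.SpanContextFromContext(ctx)
	if traceId := spanCtx.TraceID(); traceId.IsValid() {
		return "{" + traceId.String() + "} "
	}
	return ""
}

// logf prefixes the formatted message with the trace id, if any.
func logf(ctx context.Context, format string, v ...interface{}) {
	fmt.Println(traceIdPrefix(ctx) + fmt.Sprintf(format, v...))
}

func main() {
	// A bare context has no span, so the line is printed without a prefix.
	logf(context.TODO(), "open new connection, config: %s", "default")
}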
diff --git a/internal/utils/utils_debug.go b/internal/utils/utils_debug.go index c03d30dd4..7817aecf9 100644 --- a/internal/utils/utils_debug.go +++ b/internal/utils/utils_debug.go @@ -11,8 +11,8 @@ import ( ) const ( - debugKey = "gf.debug" // Debug key for checking if in debug mode. - StackFilterKeyForGoFrame = "/github.com/gogf/gf/" // Stack filtering key for all GoFrame module paths. + commandEnvKeyForDebugKey = "gf.debug" // Debug key for checking if in debug mode. + StackFilterKeyForGoFrame = "github.com/gogf/gf@" // Stack filtering key for all GoFrame module paths. ) var ( @@ -22,7 +22,7 @@ var ( func init() { // Debugging configured. - value := command.GetOptWithEnv(debugKey) + value := command.GetOptWithEnv(commandEnvKeyForDebugKey) if value == "" || value == "0" || value == "false" { isDebugEnabled = false } else { diff --git a/net/ghttp/ghttp.go b/net/ghttp/ghttp.go index ec486b0ef..54e2a4bbd 100644 --- a/net/ghttp/ghttp.go +++ b/net/ghttp/ghttp.go @@ -19,11 +19,11 @@ import ( ) type ( - // Server wraps the http.Server and provides more feature. + // Server wraps the http.Server and provides more rich features. Server struct { name string // Unique name for instance management. config ServerConfig // Configuration. - plugins []Plugin // Plugin array. + plugins []Plugin // Plugin array to extends server functionality. servers []*gracefulServer // Underlying http.Server array. serverCount *gtype.Int // Underlying http.Server count. closeChan chan struct{} // Used for underlying server closing event notification. @@ -44,7 +44,7 @@ type ( Priority int // Just for reference. } - // Router item just for route dumps. + // RouterItem is just for route dumps. RouterItem struct { Server string // Server name. Address string // Listening address. @@ -98,7 +98,7 @@ type ( Stack() string } - // Request handler function. + // HandlerFunc is request handler function. HandlerFunc = func(r *Request) // Listening file descriptor mapping. @@ -107,10 +107,6 @@ type ( ) const ( - HOOK_BEFORE_SERVE = "HOOK_BEFORE_SERVE" // Deprecated, use HookBeforeServe instead. - HOOK_AFTER_SERVE = "HOOK_AFTER_SERVE" // Deprecated, use HookAfterServe instead. - HOOK_BEFORE_OUTPUT = "HOOK_BEFORE_OUTPUT" // Deprecated, use HookBeforeOutput instead. - HOOK_AFTER_OUTPUT = "HOOK_AFTER_OUTPUT" // Deprecated, use HookAfterOutput instead. 
HookBeforeServe = "HOOK_BEFORE_SERVE" HookAfterServe = "HOOK_AFTER_SERVE" HookBeforeOutput = "HOOK_BEFORE_OUTPUT" diff --git a/net/ghttp/ghttp_request.go b/net/ghttp/ghttp_request.go index a07e115b5..fa86d02c5 100644 --- a/net/ghttp/ghttp_request.go +++ b/net/ghttp/ghttp_request.go @@ -73,7 +73,10 @@ func newRequest(s *Server, r *http.Request, w http.ResponseWriter) *Request { EnterTime: gtime.TimestampMilli(), } request.Cookie = GetCookie(request) - request.Session = s.sessionManager.New(request.GetSessionId()) + request.Session = s.sessionManager.New( + r.Context(), + request.GetSessionId(), + ) request.Response.Request = request request.Middleware = &middleware{ request: request, @@ -84,7 +87,7 @@ func newRequest(s *Server, r *http.Request, w http.ResponseWriter) *Request { address = request.RemoteAddr header = fmt.Sprintf("%v", request.Header) ) - intlog.Print(address, header) + intlog.Print(r.Context(), address, header) return guid.S([]byte(address), []byte(header)) }) if err != nil { diff --git a/net/ghttp/ghttp_request_param_file.go b/net/ghttp/ghttp_request_param_file.go index c5b803221..997b82ed4 100644 --- a/net/ghttp/ghttp_request_param_file.go +++ b/net/ghttp/ghttp_request_param_file.go @@ -7,7 +7,8 @@ package ghttp import ( - "errors" + "context" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" "github.com/gogf/gf/os/gtime" @@ -21,6 +22,7 @@ import ( // UploadFile wraps the multipart uploading file with more and convenient features. type UploadFile struct { *multipart.FileHeader + ctx context.Context } // UploadFiles is array type for *UploadFile. @@ -33,14 +35,14 @@ type UploadFiles []*UploadFile // Note that it will OVERWRITE the target file if there's already a same name file exist. func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename string, err error) { if f == nil { - return "", errors.New("file is empty, maybe you retrieve it from invalid field name or form enctype") + return "", gerror.New("file is empty, maybe you retrieve it from invalid field name or form enctype") } if !gfile.Exists(dirPath) { if err = gfile.Mkdir(dirPath); err != nil { return } } else if !gfile.IsDir(dirPath) { - return "", errors.New(`parameter "dirPath" should be a directory path`) + return "", gerror.New(`parameter "dirPath" should be a directory path`) } file, err := f.Open() @@ -60,7 +62,7 @@ func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename stri return "", err } defer newFile.Close() - intlog.Printf(`save upload file: %s`, filePath) + intlog.Printf(f.ctx, `save upload file: %s`, filePath) if _, err := io.Copy(newFile, file); err != nil { return "", err } @@ -74,7 +76,7 @@ func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename stri // The parameter <randomlyRename> specifies whether randomly renames all the file names. 
func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) (filenames []string, err error) { if len(fs) == 0 { - return nil, errors.New("file array is empty, maybe you retrieve it from invalid field name or form enctype") + return nil, gerror.New("file array is empty, maybe you retrieve it from invalid field name or form enctype") } for _, f := range fs { if filename, err := f.Save(dirPath, randomlyRename...); err != nil { @@ -114,6 +116,7 @@ func (r *Request) GetUploadFiles(name string) UploadFiles { uploadFiles := make(UploadFiles, len(multipartFiles)) for k, v := range multipartFiles { uploadFiles[k] = &UploadFile{ + ctx: r.Context(), FileHeader: v, } } diff --git a/net/ghttp/ghttp_request_param_page.go b/net/ghttp/ghttp_request_param_page.go index 478ae3ed0..20e0c8f55 100644 --- a/net/ghttp/ghttp_request_param_page.go +++ b/net/ghttp/ghttp_request_param_page.go @@ -14,7 +14,7 @@ import ( ) // GetPage creates and returns the pagination object for given <totalSize> and <pageSize>. -// NOTE THAT the page parameter name from client is constantly defined as gpage.PAGE_NAME +// NOTE THAT the page parameter name from client is constantly defined as gpage.DefaultPageName // for simplification and convenience. func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page { // It must has Router object attribute. @@ -27,7 +27,7 @@ func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page { // Check the page variable in the URI. if len(r.Router.RegNames) > 0 { for _, name := range r.Router.RegNames { - if name == gpage.PAGE_NAME { + if name == gpage.DefaultPageName { uriHasPageName = true break } @@ -38,8 +38,8 @@ func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page { urlTemplate = r.Router.Uri for i, name := range r.Router.RegNames { rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name) - if name == gpage.PAGE_NAME { - urlTemplate, _ = gregex.ReplaceString(rule, gpage.PAGE_PLACE_HOLDER, urlTemplate) + if name == gpage.DefaultPageName { + urlTemplate, _ = gregex.ReplaceString(rule, gpage.DefaultPagePlaceHolder, urlTemplate) } else { urlTemplate, _ = gregex.ReplaceString(rule, match[i+1], urlTemplate) } @@ -51,7 +51,7 @@ func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page { // Check the page variable in the query string. if !uriHasPageName { values := url.Query() - values.Set(gpage.PAGE_NAME, gpage.PAGE_PLACE_HOLDER) + values.Set(gpage.DefaultPageName, gpage.DefaultPagePlaceHolder) url.RawQuery = values.Encode() // Replace the encoded "{.page}" to original "{.page}". url.RawQuery = gstr.Replace(url.RawQuery, "%7B.page%7D", "{.page}") @@ -60,5 +60,5 @@ func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page { urlTemplate += "?" + url.RawQuery } - return gpage.New(totalSize, pageSize, r.GetInt(gpage.PAGE_NAME), urlTemplate) + return gpage.New(totalSize, pageSize, r.GetInt(gpage.DefaultPageName), urlTemplate) } diff --git a/net/ghttp/ghttp_server.go b/net/ghttp/ghttp_server.go index 5896ace8c..ab7738e38 100644 --- a/net/ghttp/ghttp_server.go +++ b/net/ghttp/ghttp_server.go @@ -8,6 +8,7 @@ package ghttp import ( "bytes" + "context" "github.com/gogf/gf/debug/gdebug" "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" @@ -70,10 +71,10 @@ func serverProcessInit() { // Process message handler. // It's enabled only graceful feature is enabled. 
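The upload changes above attach the request context to UploadFile and switch its errors to gerror. A minimal handler sketch built only on APIs visible in this diff (GetUploadFiles and UploadFiles.Save); the form field name, target directory, and port are illustrative:

```go
package main

import (
	"github.com/gogf/gf/frame/g"
	"github.com/gogf/gf/net/ghttp"
)

func main() {
	s := g.Server()
	s.BindHandler("/upload", func(r *ghttp.Request) {
		// "upload-file" is an illustrative form field name.
		files := r.GetUploadFiles("upload-file")
		// Save returns gerror-based errors now, e.g. when the field name or
		// the form enctype is wrong; true enables random renaming.
		names, err := files.Save("/tmp/uploads", true)
		if err != nil {
			r.Response.Write(err)
			return
		}
		r.Response.Write(names)
	})
	s.SetPort(8199)
	s.Run()
}
```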
if gracefulEnabled { - intlog.Printf("%d: graceful reload feature is enabled", gproc.Pid()) + intlog.Printf(context.TODO(), "%d: graceful reload feature is enabled", gproc.Pid()) go handleProcessMessage() } else { - intlog.Printf("%d: graceful reload feature is disabled", gproc.Pid()) + intlog.Printf(context.TODO(), "%d: graceful reload feature is disabled", gproc.Pid()) } // It's an ugly calling for better initializing the main package path @@ -195,7 +196,7 @@ func (s *Server) Start() error { if gproc.IsChild() { gtimer.SetTimeout(time.Duration(s.config.GracefulTimeout)*time.Second, func() { if err := gproc.Send(gproc.PPid(), []byte("exit"), adminGProcCommGroup); err != nil { - //glog.Error("server error in process communication:", err) + intlog.Error(context.TODO(), "server error in process communication:", err) } }) } @@ -315,9 +316,9 @@ func (s *Server) Run() { // Remove plugins. if len(s.plugins) > 0 { for _, p := range s.plugins { - intlog.Printf(`remove plugin: %s`, p.Name()) + intlog.Printf(context.TODO(), `remove plugin: %s`, p.Name()) if err := p.Remove(); err != nil { - intlog.Errorf("%+v", err) + intlog.Errorf(context.TODO(), "%+v", err) } } } @@ -333,7 +334,7 @@ func Wait() { s := v.(*Server) if len(s.plugins) > 0 { for _, p := range s.plugins { - intlog.Printf(`remove plugin: %s`, p.Name()) + intlog.Printf(context.TODO(), `remove plugin: %s`, p.Name()) p.Remove() } } diff --git a/net/ghttp/ghttp_server_admin_process.go b/net/ghttp/ghttp_server_admin_process.go index 577a0168d..e6fa2e622 100644 --- a/net/ghttp/ghttp_server_admin_process.go +++ b/net/ghttp/ghttp_server_admin_process.go @@ -8,8 +8,9 @@ package ghttp import ( "bytes" - "errors" + "context" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" "os" @@ -51,7 +52,7 @@ var serverProcessStatus = gtype.NewInt() // The optional parameter <newExeFilePath> specifies the new binary file for creating process. 
func RestartAllServer(newExeFilePath ...string) error { if !gracefulEnabled { - return errors.New("graceful reload feature is disabled") + return gerror.New("graceful reload feature is disabled") } serverActionLocker.Lock() defer serverActionLocker.Unlock() @@ -84,9 +85,9 @@ func checkProcessStatus() error { if status > 0 { switch status { case adminActionRestarting: - return errors.New("server is restarting") + return gerror.New("server is restarting") case adminActionShuttingDown: - return errors.New("server is shutting down") + return gerror.New("server is shutting down") } } return nil @@ -97,7 +98,7 @@ func checkProcessStatus() error { func checkActionFrequency() error { interval := gtime.TimestampMilli() - serverActionLastTime.Val() if interval < adminActionIntervalLimit { - return errors.New(fmt.Sprintf("too frequent action, please retry in %d ms", adminActionIntervalLimit-interval)) + return gerror.Newf("too frequent action, please retry in %d ms", adminActionIntervalLimit-interval) } serverActionLastTime.Set(gtime.TimestampMilli()) return nil @@ -173,7 +174,7 @@ func bufferToServerFdMap(buffer []byte) map[string]listenerFdMap { sfm := make(map[string]listenerFdMap) if len(buffer) > 0 { j, _ := gjson.LoadContent(buffer) - for k, _ := range j.ToMap() { + for k, _ := range j.Map() { m := make(map[string]string) for k, v := range j.GetMap(k) { m[k] = gconv.String(v) @@ -266,10 +267,10 @@ func handleProcessMessage() { for { if msg := gproc.Receive(adminGProcCommGroup); msg != nil { if bytes.EqualFold(msg.Data, []byte("exit")) { - intlog.Printf("%d: process message: exit", gproc.Pid()) + intlog.Printf(context.TODO(), "%d: process message: exit", gproc.Pid()) shutdownWebServersGracefully() allDoneChan <- struct{}{} - intlog.Printf("%d: process message: exit done", gproc.Pid()) + intlog.Printf(context.TODO(), "%d: process message: exit done", gproc.Pid()) return } } diff --git a/net/ghttp/ghttp_server_admin_unix.go b/net/ghttp/ghttp_server_admin_unix.go index b4f0cf6b7..3e5591473 100644 --- a/net/ghttp/ghttp_server_admin_unix.go +++ b/net/ghttp/ghttp_server_admin_unix.go @@ -9,6 +9,7 @@ package ghttp import ( + "context" "github.com/gogf/gf/internal/intlog" "os" "os/signal" @@ -33,7 +34,7 @@ func handleProcessSignal() { ) for { sig = <-procSignalChan - intlog.Printf(`signal received: %s`, sig.String()) + intlog.Printf(context.TODO(), `signal received: %s`, sig.String()) switch sig { // Shutdown the servers. case syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGABRT: @@ -49,7 +50,7 @@ func handleProcessSignal() { // Restart the servers. case syscall.SIGUSR1: if err := restartWebServers(sig.String()); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } return diff --git a/net/ghttp/ghttp_server_config.go b/net/ghttp/ghttp_server_config.go index e7b19c926..5021c26d8 100644 --- a/net/ghttp/ghttp_server_config.go +++ b/net/ghttp/ghttp_server_config.go @@ -7,6 +7,7 @@ package ghttp import ( + "context" "crypto/tls" "fmt" "github.com/gogf/gf/internal/intlog" @@ -223,13 +224,14 @@ type ServerConfig struct { GracefulTimeout uint8 `json:"gracefulTimeout"` } +// Config creates and returns a ServerConfig object with default configurations. // Deprecated. Use NewConfig instead. func Config() ServerConfig { return NewConfig() } // NewConfig creates and returns a ServerConfig object with default configurations. 
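A recurring change in this commit is replacing errors.New(fmt.Sprintf(...)) with gerror.New/gerror.Newf, which keeps the message format but creates the error through the framework's error package. A standalone sketch of the pattern; the stack rendering via %+v is assumed gerror behavior, not something shown in the diff:

```go
package main

import (
	"fmt"

	"github.com/gogf/gf/errors/gerror"
)

// checkIndex mirrors the bounds checks rewritten above: gerror.Newf replaces
// errors.New(fmt.Sprintf(...)) while producing the same message.
func checkIndex(index, length int) error {
	if index < 0 || index >= length {
		return gerror.Newf("index %d out of array range %d", index, length)
	}
	return nil
}

func main() {
	if err := checkIndex(10, 3); err != nil {
		fmt.Printf("%v\n", err)  // message only
		fmt.Printf("%+v\n", err) // message plus stack trace (assumed gerror formatting)
	}
}
```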
-// Note that, do not define this default configuration to local package variable, as there're +// Note that, do not define this default configuration to local package variable, as there are // some pointer attributes that may be shared in different servers. func NewConfig() ServerConfig { return ServerConfig{ @@ -331,10 +333,12 @@ func (s *Server) SetConfig(c ServerConfig) error { return err } } - s.config.Logger.SetLevelStr(s.config.LogLevel) + if err := s.config.Logger.SetLevelStr(s.config.LogLevel); err != nil { + intlog.Error(context.TODO(), err) + } SetGraceful(c.Graceful) - intlog.Printf("SetConfig: %+v", s.config) + intlog.Printf(context.TODO(), "SetConfig: %+v", s.config) return nil } diff --git a/net/ghttp/ghttp_server_config_logging.go b/net/ghttp/ghttp_server_config_logging.go index 3ea28f0de..471ea876c 100644 --- a/net/ghttp/ghttp_server_config_logging.go +++ b/net/ghttp/ghttp_server_config_logging.go @@ -6,6 +6,8 @@ package ghttp +import "github.com/gogf/gf/os/glog" + // SetLogPath sets the log path for server. // It logs content to file only if the log path is set. func (s *Server) SetLogPath(path string) error { @@ -23,6 +25,17 @@ func (s *Server) SetLogPath(path string) error { return nil } +// SetLogger sets the logger for logging responsibility. +// Note that it cannot be set in runtime as there may be concurrent safety issue. +func (s *Server) SetLogger(logger *glog.Logger) { + s.config.Logger = logger +} + +// Logger is alias of GetLogger. +func (s *Server) Logger() *glog.Logger { + return s.config.Logger +} + // SetLogLevel sets logging level by level string. func (s *Server) SetLogLevel(level string) { s.config.LogLevel = level diff --git a/net/ghttp/ghttp_server_graceful.go b/net/ghttp/ghttp_server_graceful.go index e4ca2553c..2e6f82390 100644 --- a/net/ghttp/ghttp_server_graceful.go +++ b/net/ghttp/ghttp_server_graceful.go @@ -9,8 +9,8 @@ package ghttp import ( "context" "crypto/tls" - "errors" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gproc" "github.com/gogf/gf/os/gres" "github.com/gogf/gf/text/gstr" @@ -122,7 +122,7 @@ func (s *gracefulServer) ListenAndServeTLS(certFile, keyFile string, tlsConfig . } if err != nil { - return errors.New(fmt.Sprintf(`open cert file "%s","%s" failed: %s`, certFile, keyFile, err.Error())) + return gerror.Newf(`open cert file "%s","%s" failed: %s`, certFile, keyFile, err.Error()) } ln, err := s.getNetListener() if err != nil { diff --git a/net/ghttp/ghttp_server_handler.go b/net/ghttp/ghttp_server_handler.go index bc2bb05d9..64bdf0ee1 100644 --- a/net/ghttp/ghttp_server_handler.go +++ b/net/ghttp/ghttp_server_handler.go @@ -7,6 +7,7 @@ package ghttp import ( + "github.com/gogf/gf/internal/intlog" "net/http" "os" "sort" @@ -80,9 +81,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Close the request and response body // to release the file descriptor in time. - _ = request.Request.Body.Close() + err := request.Request.Body.Close() + if err != nil { + intlog.Error(request.Context(), err) + } if request.Request.Response != nil { - _ = request.Request.Response.Body.Close() + err = request.Request.Response.Body.Close() + if err != nil { + intlog.Error(request.Context(), err) + } } }() @@ -188,9 +195,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // searchStaticFile searches the file with given URI. // It returns a file struct specifying the file information. 
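SetLogger and Logger now live beside the other logging configuration of the server, and Logger returns the configured glog.Logger. A sketch of wiring a custom logger in before start-up; the log path and port are placeholders:

```go
package main

import (
	"github.com/gogf/gf/frame/g"
	"github.com/gogf/gf/os/glog"
)

func main() {
	s := g.Server()

	// Per the comment above, SetLogger is not safe to call at runtime,
	// so configure the logger before the server starts.
	logger := glog.New()
	if err := logger.SetPath("/tmp/ghttp-logs"); err != nil { // placeholder path
		panic(err)
	}
	s.SetLogger(logger)

	// Logger() returns the same configured logger.
	s.Logger().Print("custom server logger configured")

	s.SetPort(8199)
	s.Run()
}
```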
func (s *Server) searchStaticFile(uri string) *staticFile { - var file *gres.File - var path string - var dir bool + var ( + file *gres.File + path string + dir bool + ) // Firstly search the StaticPaths mapping. if len(s.config.StaticPaths) > 0 { for _, item := range s.config.StaticPaths { diff --git a/net/ghttp/ghttp_server_log.go b/net/ghttp/ghttp_server_log.go index c8f0f5f8f..24ce77bcc 100644 --- a/net/ghttp/ghttp_server_log.go +++ b/net/ghttp/ghttp_server_log.go @@ -9,14 +9,8 @@ package ghttp import ( "fmt" "github.com/gogf/gf/errors/gerror" - "github.com/gogf/gf/os/glog" ) -// Logger returns the logger of the server. -func (s *Server) Logger() *glog.Logger { - return s.config.Logger -} - // handleAccessLog handles the access logging for server. func (s *Server) handleAccessLog(r *Request) { if !s.IsAccessLogEnabled() { @@ -26,7 +20,7 @@ func (s *Server) handleAccessLog(r *Request) { if r.TLS != nil { scheme = "https" } - s.Logger().File(s.config.AccessLogPattern). + s.Logger().Ctx(r.Context()).File(s.config.AccessLogPattern). Stdout(s.config.LogStdout). Printf( `%d "%s %s %s %s %s" %.3f, %s, "%s", "%s"`, @@ -63,7 +57,7 @@ func (s *Server) handleErrorLog(err error, r *Request) { } else { content += ", " + err.Error() } - s.config.Logger. + s.Logger().Ctx(r.Context()). File(s.config.ErrorLogPattern). Stdout(s.config.LogStdout). Print(content) diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index cc7d883b3..6d31d5e78 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -7,9 +7,9 @@ package ghttp import ( - "errors" "fmt" "github.com/gogf/gf/container/gtype" + "github.com/gogf/gf/errors/gerror" "strings" "github.com/gogf/gf/debug/gdebug" @@ -53,7 +53,7 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err } } if path == "" { - err = errors.New("invalid pattern: URI should not be empty") + err = gerror.New("invalid pattern: URI should not be empty") } if path != "/" { path = strings.TrimRight(path, "/") diff --git a/net/ghttp/ghttp_unit_request_test.go b/net/ghttp/ghttp_unit_request_test.go index 26880df6b..96ba9a375 100644 --- a/net/ghttp/ghttp_unit_request_test.go +++ b/net/ghttp/ghttp_unit_request_test.go @@ -621,3 +621,89 @@ func Test_Params_Parse_Validation(t *testing.T) { t.Assert(client.GetContent("/parse?name=john11&password1=123456&password2=123456"), `ok`) }) } + +func Test_Params_Parse_EmbeddedWithAliasName1(t *testing.T) { + // 获取内容列表 + type ContentGetListInput struct { + Type string + CategoryId uint + Page int + Size int + Sort int + UserId uint + } + // 获取内容列表 + type ContentGetListReq struct { + ContentGetListInput + CategoryId uint `p:"cate"` + Page int `d:"1" v:"min:0#分页号码错误"` + Size int `d:"10" v:"max:50#分页数量最大50条"` + } + + p, _ := ports.PopRand() + s := g.Server(p) + s.BindHandler("/parse", func(r *ghttp.Request) { + var req *ContentGetListReq + if err := r.Parse(&req); err != nil { + r.Response.Write(err) + } else { + r.Response.Write(req.ContentGetListInput) + } + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + prefix := fmt.Sprintf("http://127.0.0.1:%d", p) + client := g.Client() + client.SetPrefix(prefix) + + t.Assert(client.GetContent("/parse?cate=1&page=2&size=10"), `{"Type":"","CategoryId":0,"Page":2,"Size":10,"Sort":0,"UserId":0}`) + }) +} + +func Test_Params_Parse_EmbeddedWithAliasName2(t *testing.T) { + // 获取内容列表 + type ContentGetListInput struct { + Type string 
+ CategoryId uint `p:"cate"` + Page int + Size int + Sort int + UserId uint + } + // 获取内容列表 + type ContentGetListReq struct { + ContentGetListInput + CategoryId uint `p:"cate"` + Page int `d:"1" v:"min:0#分页号码错误"` + Size int `d:"10" v:"max:50#分页数量最大50条"` + } + + p, _ := ports.PopRand() + s := g.Server(p) + s.BindHandler("/parse", func(r *ghttp.Request) { + var req *ContentGetListReq + if err := r.Parse(&req); err != nil { + r.Response.Write(err) + } else { + r.Response.Write(req.ContentGetListInput) + } + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + prefix := fmt.Sprintf("http://127.0.0.1:%d", p) + client := g.Client() + client.SetPrefix(prefix) + + t.Assert(client.GetContent("/parse?cate=1&page=2&size=10"), `{"Type":"","CategoryId":1,"Page":2,"Size":10,"Sort":0,"UserId":0}`) + }) +} diff --git a/net/ghttp/ghttp_unit_router_domain_basic_test.go b/net/ghttp/ghttp_unit_router_domain_basic_test.go index 5b597cb6f..3db0f35cc 100644 --- a/net/ghttp/ghttp_unit_router_domain_basic_test.go +++ b/net/ghttp/ghttp_unit_router_domain_basic_test.go @@ -335,15 +335,15 @@ func Test_Router_DomainGroup(t *testing.T) { d.Group("/", func(group *ghttp.RouterGroup) { group.Group("/app", func(gApp *ghttp.RouterGroup) { gApp.GET("/{table}/list/{page}.html", func(r *ghttp.Request) { - intlog.Print("/{table}/list/{page}.html") + intlog.Print(r.Context(), "/{table}/list/{page}.html") r.Response.Write(r.Get("table"), "&", r.Get("page")) }) gApp.GET("/order/info/{order_id}", func(r *ghttp.Request) { - intlog.Print("/order/info/{order_id}") + intlog.Print(r.Context(), "/order/info/{order_id}") r.Response.Write(r.Get("order_id")) }) gApp.DELETE("/comment/{id}", func(r *ghttp.Request) { - intlog.Print("/comment/{id}") + intlog.Print(r.Context(), "/comment/{id}") r.Response.Write(r.Get("id")) }) }) diff --git a/net/ghttp/internal/client/client_request.go b/net/ghttp/internal/client/client_request.go index 545e8e7b9..90b8970b9 100644 --- a/net/ghttp/internal/client/client_request.go +++ b/net/ghttp/internal/client/client_request.go @@ -9,8 +9,7 @@ package client import ( "bytes" "context" - "errors" - "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/internal/utils" @@ -189,18 +188,18 @@ func (c *Client) prepareRequest(method, url string, data ...interface{}) (req *h if len(array[1]) > 6 && strings.Compare(array[1][0:6], "@file:") == 0 { path := array[1][6:] if !gfile.Exists(path) { - return nil, errors.New(fmt.Sprintf(`"%s" does not exist`, path)) + return nil, gerror.Newf(`"%s" does not exist`, path) } if file, err := writer.CreateFormFile(array[0], gfile.Basename(path)); err == nil { if f, err := os.Open(path); err == nil { if _, err = io.Copy(file, f); err != nil { if err := f.Close(); err != nil { - intlog.Errorf(`%+v`, err) + intlog.Errorf(c.ctx, `%+v`, err) } return nil, err } if err := f.Close(); err != nil { - intlog.Errorf(`%+v`, err) + intlog.Errorf(c.ctx, `%+v`, err) } } else { return nil, err @@ -303,7 +302,7 @@ func (c *Client) callRequest(req *http.Request) (resp *Response, err error) { // The response might not be nil when err != nil. 
if resp.Response != nil { if err := resp.Response.Body.Close(); err != nil { - intlog.Errorf(`%+v`, err) + intlog.Errorf(c.ctx, `%+v`, err) } } if c.retryCount > 0 { diff --git a/net/gipv4/gipv4_ip.go b/net/gipv4/gipv4_ip.go index 46d2aab8e..29c113d23 100644 --- a/net/gipv4/gipv4_ip.go +++ b/net/gipv4/gipv4_ip.go @@ -8,7 +8,7 @@ package gipv4 import ( - "errors" + "github.com/gogf/gf/errors/gerror" "net" "strconv" "strings" @@ -38,7 +38,7 @@ func GetIntranetIp() (ip string, err error) { return "", err } if len(ips) == 0 { - return "", errors.New("no intranet ip found") + return "", gerror.New("no intranet ip found") } return ips[0], nil } diff --git a/net/gtcp/gtcp_server.go b/net/gtcp/gtcp_server.go index f4ec7363c..d8c4f22fd 100644 --- a/net/gtcp/gtcp_server.go +++ b/net/gtcp/gtcp_server.go @@ -8,7 +8,7 @@ package gtcp import ( "crypto/tls" - "errors" + "github.com/gogf/gf/errors/gerror" "net" "sync" @@ -116,7 +116,7 @@ func (s *Server) Close() error { // Run starts running the TCP Server. func (s *Server) Run() (err error) { if s.handler == nil { - err = errors.New("start running failed: socket handler not defined") + err = gerror.New("start running failed: socket handler not defined") glog.Error(err) return } diff --git a/net/gtrace/gtrace.go b/net/gtrace/gtrace.go index 34d906cee..6d847474b 100644 --- a/net/gtrace/gtrace.go +++ b/net/gtrace/gtrace.go @@ -9,7 +9,6 @@ package gtrace import ( "context" - "fmt" "github.com/gogf/gf/container/gmap" "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/net/gipv4" @@ -23,15 +22,17 @@ import ( ) const ( - tracingCommonKeyIpIntranet = `ip.intranet` - tracingCommonKeyIpHostname = `hostname` - cmdEnvKey = "gf.gtrace" // Configuration key for command argument or environment. + tracingCommonKeyIpIntranet = `ip.intranet` + tracingCommonKeyIpHostname = `hostname` + commandEnvKeyForMaxContentLogSize = "gf.gtrace.maxcontentlogsize" + commandEnvKeyForTracingInternal = "gf.gtrace.tracinginternal" ) var ( intranetIps, _ = gipv4.GetIntranetIpArray() intranetIpStr = strings.Join(intranetIps, ",") hostname, _ = os.Hostname() + tracingInternal = true // tracingInternal enables tracing for internal type spans. tracingMaxContentLogSize = 256 * 1024 // Max log size for request and response body, especially for HTTP/RPC request. // defaultTextMapPropagator is the default propagator for context propagation between peers. defaultTextMapPropagator = propagation.NewCompositeTextMapPropagator( @@ -41,12 +42,18 @@ var ( ) func init() { - if maxContentLogSize := gcmd.GetOptWithEnv(fmt.Sprintf("%s.maxcontentlogsize", cmdEnvKey)).Int(); maxContentLogSize > 0 { + tracingInternal = gcmd.GetOptWithEnv(commandEnvKeyForTracingInternal, true).Bool() + if maxContentLogSize := gcmd.GetOptWithEnv(commandEnvKeyForMaxContentLogSize).Int(); maxContentLogSize > 0 { tracingMaxContentLogSize = maxContentLogSize } CheckSetDefaultTextMapPropagator() } +// IsTracingInternal returns whether tracing spans of internal components. +func IsTracingInternal() bool { + return tracingInternal +} + // MaxContentLogSize returns the max log size for request and response body, especially for HTTP/RPC request. 
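The gtrace changes split the old gf.gtrace key into two explicit options and expose IsTracingInternal. A sketch of reading both values; the option names come from the diff, while the matching environment-variable spelling (upper-cased, dots replaced by underscores) is an assumption about gcmd's option/env mapping:

```go
package main

import (
	"fmt"

	"github.com/gogf/gf/net/gtrace"
)

func main() {
	// Configure via command line, for example:
	//   ./app --gf.gtrace.tracinginternal=false --gf.gtrace.maxcontentlogsize=1024
	// or the corresponding environment variables (assumed naming,
	// e.g. GF_GTRACE_TRACINGINTERNAL).
	fmt.Println("tracing internal spans:", gtrace.IsTracingInternal())
	fmt.Println("max content log size:  ", gtrace.MaxContentLogSize())
}
```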
func MaxContentLogSize() int { return tracingMaxContentLogSize diff --git a/net/gudp/gudp_server.go b/net/gudp/gudp_server.go index db7dca6a4..4987787c9 100644 --- a/net/gudp/gudp_server.go +++ b/net/gudp/gudp_server.go @@ -7,7 +7,7 @@ package gudp import ( - "errors" + "github.com/gogf/gf/errors/gerror" "net" "github.com/gogf/gf/container/gmap" @@ -78,7 +78,7 @@ func (s *Server) Close() error { // Run starts listening UDP connection. func (s *Server) Run() error { if s.handler == nil { - err := errors.New("start running failed: socket handler not defined") + err := gerror.New("start running failed: socket handler not defined") glog.Error(err) return err } diff --git a/os/gbuild/gbuild.go b/os/gbuild/gbuild.go index 5447c0cb7..c44cd57fa 100644 --- a/os/gbuild/gbuild.go +++ b/os/gbuild/gbuild.go @@ -8,6 +8,7 @@ package gbuild import ( + "context" "github.com/gogf/gf" "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/encoding/gbase64" @@ -26,13 +27,13 @@ func init() { if builtInVarStr != "" { err := json.UnmarshalUseNumber(gbase64.MustDecodeString(builtInVarStr), &builtInVarMap) if err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } builtInVarMap["gfVersion"] = gf.VERSION builtInVarMap["goVersion"] = runtime.Version() - intlog.Printf("build variables: %+v", builtInVarMap) + intlog.Printf(context.TODO(), "build variables: %+v", builtInVarMap) } else { - intlog.Print("no build variables") + intlog.Print(context.TODO(), "no build variables") } } diff --git a/os/gcfg/gcfg.go b/os/gcfg/gcfg.go index d9eac9f2b..ec8bf5153 100644 --- a/os/gcfg/gcfg.go +++ b/os/gcfg/gcfg.go @@ -8,6 +8,7 @@ package gcfg import ( + "context" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/container/gmap" "github.com/gogf/gf/internal/intlog" @@ -23,10 +24,11 @@ type Config struct { } const ( - DefaultName = "config" // DefaultName is the default group name for instance usage. - DefaultConfigFile = "config.toml" // DefaultConfigFile is the default configuration file name. - cmdEnvKey = "gf.gcfg" // cmdEnvKey is the configuration key for command argument or environment. - errorPrintKey = "gf.gcfg.errorprint" // errorPrintKey is used to specify the key controlling error printing to stdout. + DefaultName = "config" // DefaultName is the default group name for instance usage. + DefaultConfigFile = "config.toml" // DefaultConfigFile is the default configuration file name. + commandEnvKeyForFile = "gf.gcfg.file" // commandEnvKeyForFile is the configuration key for command argument or environment configuring file name. + commandEnvKeyForPath = "gf.gcfg.path" // commandEnvKeyForPath is the configuration key for command argument or environment configuring directory path. + commandEnvKeyForErrorPrint = "gf.gcfg.errorprint" // commandEnvKeyForErrorPrint is used to specify the key controlling error printing to stdout. ) var ( @@ -81,7 +83,7 @@ func RemoveContent(file ...string) { } }) - intlog.Printf(`RemoveContent: %s`, name) + intlog.Printf(context.TODO(), `RemoveContent: %s`, name) } // ClearContent removes all global configuration contents. @@ -94,10 +96,10 @@ func ClearContent() { } }) - intlog.Print(`RemoveConfig`) + intlog.Print(context.TODO(), `RemoveConfig`) } // errorPrint checks whether printing error to stdout. 
func errorPrint() bool { - return gcmd.GetOptWithEnv(errorPrintKey, true).Bool() + return gcmd.GetOptWithEnv(commandEnvKeyForErrorPrint, true).Bool() } diff --git a/os/gcfg/gcfg_config.go b/os/gcfg/gcfg_config.go index 6034f1f88..425a09df5 100644 --- a/os/gcfg/gcfg_config.go +++ b/os/gcfg/gcfg_config.go @@ -8,7 +8,7 @@ package gcfg import ( "bytes" - "errors" + "context" "fmt" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/container/gmap" @@ -33,7 +33,7 @@ func New(file ...string) *Config { name = file[0] } else { // Custom default configuration file name from command line or environment. - if customFile := gcmd.GetOptWithEnv(fmt.Sprintf("%s.file", cmdEnvKey)).String(); customFile != "" { + if customFile := gcmd.GetOptWithEnv(commandEnvKeyForFile).String(); customFile != "" { name = customFile } } @@ -43,7 +43,7 @@ func New(file ...string) *Config { jsonMap: gmap.NewStrAnyMap(true), } // Customized dir path from env/cmd. - if customPath := gcmd.GetOptWithEnv(fmt.Sprintf("%s.path", cmdEnvKey)).String(); customPath != "" { + if customPath := gcmd.GetOptWithEnv(commandEnvKeyForPath).String(); customPath != "" { if gfile.Exists(customPath) { _ = c.SetPath(customPath) } else { @@ -54,20 +54,20 @@ func New(file ...string) *Config { } else { // Dir path of working dir. if err := c.AddPath(gfile.Pwd()); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } // Dir path of main package. if mainPath := gfile.MainPkgPath(); mainPath != "" && gfile.Exists(mainPath) { if err := c.AddPath(mainPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } // Dir path of binary. if selfPath := gfile.SelfDir(); selfPath != "" && gfile.Exists(selfPath) { if err := c.AddPath(selfPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } } @@ -142,7 +142,7 @@ func (c *Config) SetPath(path string) error { } else { buffer.WriteString(fmt.Sprintf(`[gcfg] SetPath failed: path "%s" does not exist`, path)) } - err := errors.New(buffer.String()) + err := gerror.New(buffer.String()) if errorPrint() { glog.Error(err) } @@ -163,7 +163,7 @@ func (c *Config) SetPath(path string) error { c.jsonMap.Clear() c.searchPaths.Clear() c.searchPaths.Append(realPath) - intlog.Print("SetPath:", realPath) + intlog.Print(context.TODO(), "SetPath:", realPath) return nil } @@ -237,7 +237,7 @@ func (c *Config) AddPath(path string) error { return nil } c.searchPaths.Append(realPath) - intlog.Print("AddPath:", realPath) + intlog.Print(context.TODO(), "AddPath:", realPath) return nil } diff --git a/os/gcfg/gcfg_config_api.go b/os/gcfg/gcfg_config_api.go index 1fafb391b..7e1f0625e 100644 --- a/os/gcfg/gcfg_config_api.go +++ b/os/gcfg/gcfg_config_api.go @@ -7,7 +7,7 @@ package gcfg import ( - "errors" + "github.com/gogf/gf/errors/gerror" "time" "github.com/gogf/gf/encoding/gjson" @@ -295,16 +295,7 @@ func (c *Config) GetStruct(pattern string, pointer interface{}, mapping ...map[s if j := c.getJson(); j != nil { return j.GetStruct(pattern, pointer, mapping...) } - return errors.New("configuration not found") -} - -// GetStructDeep does GetStruct recursively. -// Deprecated, use GetStruct instead. -func (c *Config) GetStructDeep(pattern string, pointer interface{}, mapping ...map[string]string) error { - if j := c.getJson(); j != nil { - return j.GetStructDeep(pattern, pointer, mapping...) - } - return errors.New("configuration not found") + return gerror.New("configuration not found") } // GetStructs converts any slice to given struct slice. 
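With gf.gcfg.file and gf.gcfg.path as dedicated keys, the configuration file name and search path can be overridden per process without touching code. A sketch using gcfg.New and GetStruct as shown above; the struct shape and the "server" node are illustrative:

```go
package main

import (
	"fmt"

	"github.com/gogf/gf/os/gcfg"
)

// ServerConf is an illustrative shape for a "server" node in the config file.
type ServerConf struct {
	Address string
	LogPath string
}

func main() {
	// Override the defaults from the command line, for example:
	//   ./app --gf.gcfg.file=config.prod.toml --gf.gcfg.path=/etc/myapp
	c := gcfg.New()

	var conf ServerConf
	if err := c.GetStruct("server", &conf); err != nil {
		// Without any configuration file this is the
		// "configuration not found" gerror from above.
		fmt.Println(err)
		return
	}
	fmt.Printf("%+v\n", conf)
}
```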
@@ -312,16 +303,7 @@ func (c *Config) GetStructs(pattern string, pointer interface{}, mapping ...map[ if j := c.getJson(); j != nil { return j.GetStructs(pattern, pointer, mapping...) } - return errors.New("configuration not found") -} - -// GetStructsDeep converts any slice to given struct slice recursively. -// Deprecated, use GetStructs instead. -func (c *Config) GetStructsDeep(pattern string, pointer interface{}, mapping ...map[string]string) error { - if j := c.getJson(); j != nil { - return j.GetStructsDeep(pattern, pointer, mapping...) - } - return errors.New("configuration not found") + return gerror.New("configuration not found") } // GetMapToMap retrieves the value by specified `pattern` and converts it to specified map variable. @@ -330,7 +312,7 @@ func (c *Config) GetMapToMap(pattern string, pointer interface{}, mapping ...map if j := c.getJson(); j != nil { return j.GetMapToMap(pattern, pointer, mapping...) } - return errors.New("configuration not found") + return gerror.New("configuration not found") } // GetMapToMaps retrieves the value by specified `pattern` and converts it to specified map slice @@ -340,7 +322,7 @@ func (c *Config) GetMapToMaps(pattern string, pointer interface{}, mapping ...ma if j := c.getJson(); j != nil { return j.GetMapToMaps(pattern, pointer, mapping...) } - return errors.New("configuration not found") + return gerror.New("configuration not found") } // GetMapToMapsDeep retrieves the value by specified `pattern` and converts it to specified map slice @@ -350,88 +332,60 @@ func (c *Config) GetMapToMapsDeep(pattern string, pointer interface{}, mapping . if j := c.getJson(); j != nil { return j.GetMapToMapsDeep(pattern, pointer, mapping...) } - return errors.New("configuration not found") + return gerror.New("configuration not found") } -// ToMap converts current Json object to map[string]interface{}. -// It returns nil if fails. -func (c *Config) ToMap() map[string]interface{} { +// Map converts current Json object to map[string]interface{}. It returns nil if fails. +func (c *Config) Map() map[string]interface{} { if j := c.getJson(); j != nil { - return j.ToMap() + return j.Map() } return nil } -// ToArray converts current Json object to []interface{}. +// Array converts current Json object to []interface{}. // It returns nil if fails. -func (c *Config) ToArray() []interface{} { +func (c *Config) Array() []interface{} { if j := c.getJson(); j != nil { - return j.ToArray() + return j.Array() } return nil } -// ToStruct converts current Json object to specified object. +// Struct converts current Json object to specified object. // The `pointer` should be a pointer type of *struct. -func (c *Config) ToStruct(pointer interface{}, mapping ...map[string]string) error { +func (c *Config) Struct(pointer interface{}, mapping ...map[string]string) error { if j := c.getJson(); j != nil { - return j.ToStruct(pointer, mapping...) + return j.Struct(pointer, mapping...) } - return errors.New("configuration not found") + return gerror.New("configuration not found") } -// ToStructDeep converts current Json object to specified object recursively. -// The `pointer` should be a pointer type of *struct. -func (c *Config) ToStructDeep(pointer interface{}, mapping ...map[string]string) error { - if j := c.getJson(); j != nil { - return j.ToStructDeep(pointer, mapping...) - } - return errors.New("configuration not found") -} - -// ToStructs converts current Json object to specified object slice. +// Structs converts current Json object to specified object slice. 
// The `pointer` should be a pointer type of []struct/*struct. -func (c *Config) ToStructs(pointer interface{}, mapping ...map[string]string) error { +func (c *Config) Structs(pointer interface{}, mapping ...map[string]string) error { if j := c.getJson(); j != nil { - return j.ToStructs(pointer, mapping...) + return j.Structs(pointer, mapping...) } - return errors.New("configuration not found") + return gerror.New("configuration not found") } -// ToStructsDeep converts current Json object to specified object slice recursively. -// The `pointer` should be a pointer type of []struct/*struct. -func (c *Config) ToStructsDeep(pointer interface{}, mapping ...map[string]string) error { - if j := c.getJson(); j != nil { - return j.ToStructsDeep(pointer, mapping...) - } - return errors.New("configuration not found") -} - -// ToMapToMap converts current Json object to specified map variable. +// MapToMap converts current Json object to specified map variable. // The parameter of `pointer` should be type of *map. -func (c *Config) ToMapToMap(pointer interface{}, mapping ...map[string]string) error { +func (c *Config) MapToMap(pointer interface{}, mapping ...map[string]string) error { if j := c.getJson(); j != nil { - return j.ToMapToMap(pointer, mapping...) + return j.MapToMap(pointer, mapping...) } - return errors.New("configuration not found") + return gerror.New("configuration not found") } -// ToMapToMaps converts current Json object to specified map variable slice. +// MapToMaps converts current Json object to specified map variable slice. // The parameter of `pointer` should be type of []map/*map. -func (c *Config) ToMapToMaps(pointer interface{}, mapping ...map[string]string) error { +func (c *Config) MapToMaps(pointer interface{}, mapping ...map[string]string) error { if j := c.getJson(); j != nil { - return j.ToMapToMaps(pointer, mapping...) + return j.MapToMaps(pointer, mapping...) } - return errors.New("configuration not found") -} - -// ToMapToMapsDeep converts current Json object to specified map variable slice recursively. -// The parameter of `pointer` should be type of []map/*map. -func (c *Config) ToMapToMapsDeep(pointer interface{}, mapping ...map[string]string) error { - if j := c.getJson(); j != nil { - return j.ToMapToMapsDeep(pointer, mapping...) - } - return errors.New("configuration not found") + return gerror.New("configuration not found") } // Clear removes all parsed configuration files content cache, diff --git a/os/gcmd/gcmd_handler.go b/os/gcmd/gcmd_handler.go index 2bcb7a905..73d6c73c3 100644 --- a/os/gcmd/gcmd_handler.go +++ b/os/gcmd/gcmd_handler.go @@ -8,13 +8,13 @@ package gcmd import ( - "errors" + "github.com/gogf/gf/errors/gerror" ) // BindHandle registers callback function <f> with <cmd>. 
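The To*-prefixed converters on gcfg.Config are renamed (ToMap becomes Map, ToStruct becomes Struct, ToMapToMaps becomes MapToMaps, and so on) and the *Deep variants are removed. A sketch of the new names; the AppConfig shape and the contents of config.toml are hypothetical:

```go
package main

import (
	"fmt"

	"github.com/gogf/gf/os/gcfg"
)

// AppConfig is a hypothetical shape for the whole configuration file.
type AppConfig struct {
	Name  string
	Debug bool
}

func main() {
	c := gcfg.New() // defaults to config.toml in the search paths

	// Map replaces ToMap; it returns nil when no configuration is loaded.
	if m := c.Map(); m != nil {
		fmt.Println("top-level keys:", len(m))
	}

	// Struct replaces ToStruct (and the removed ToStructDeep).
	var conf AppConfig
	if err := c.Struct(&conf); err != nil {
		fmt.Println(err) // "configuration not found" when no file is present
		return
	}
	fmt.Printf("%+v\n", conf)
}
```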
func BindHandle(cmd string, f func()) error { if _, ok := defaultCommandFuncMap[cmd]; ok { - return errors.New("duplicated handle for command:" + cmd) + return gerror.New("duplicated handle for command:" + cmd) } else { defaultCommandFuncMap[cmd] = f } @@ -37,7 +37,7 @@ func RunHandle(cmd string) error { if handle, ok := defaultCommandFuncMap[cmd]; ok { handle() } else { - return errors.New("no handle found for command:" + cmd) + return gerror.New("no handle found for command:" + cmd) } return nil } @@ -49,10 +49,10 @@ func AutoRun() error { if handle, ok := defaultCommandFuncMap[cmd]; ok { handle() } else { - return errors.New("no handle found for command:" + cmd) + return gerror.New("no handle found for command:" + cmd) } } else { - return errors.New("no command found") + return gerror.New("no command found") } return nil } diff --git a/os/gcmd/gcmd_parser.go b/os/gcmd/gcmd_parser.go index f18e5093e..1ac003ae2 100644 --- a/os/gcmd/gcmd_parser.go +++ b/os/gcmd/gcmd_parser.go @@ -8,15 +8,13 @@ package gcmd import ( - "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/json" "os" "strings" "github.com/gogf/gf/text/gstr" - "errors" - "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/text/gregex" @@ -96,7 +94,7 @@ func ParseWithArgs(args []string, supportedOptions map[string]bool, strict ...bo i++ continue } else if parser.strict { - return nil, errors.New(fmt.Sprintf(`invalid option '%s'`, args[i])) + return nil, gerror.Newf(`invalid option '%s'`, args[i]) } } } diff --git a/os/gcmd/gcmd_parser_handler.go b/os/gcmd/gcmd_parser_handler.go index a2ab79eee..6e61712f4 100644 --- a/os/gcmd/gcmd_parser_handler.go +++ b/os/gcmd/gcmd_parser_handler.go @@ -8,20 +8,20 @@ package gcmd import ( - "errors" + "github.com/gogf/gf/errors/gerror" ) // BindHandle registers callback function <f> with <cmd>. func (p *Parser) BindHandle(cmd string, f func()) error { if _, ok := p.commandFuncMap[cmd]; ok { - return errors.New("duplicated handle for command:" + cmd) + return gerror.New("duplicated handle for command:" + cmd) } else { p.commandFuncMap[cmd] = f } return nil } -// BindHandle registers callback function with map <m>. +// BindHandleMap registers callback function with map <m>. 
func (p *Parser) BindHandleMap(m map[string]func()) error { var err error for k, v := range m { @@ -37,7 +37,7 @@ func (p *Parser) RunHandle(cmd string) error { if handle, ok := p.commandFuncMap[cmd]; ok { handle() } else { - return errors.New("no handle found for command:" + cmd) + return gerror.New("no handle found for command:" + cmd) } return nil } @@ -49,10 +49,10 @@ func (p *Parser) AutoRun() error { if handle, ok := p.commandFuncMap[cmd]; ok { handle() } else { - return errors.New("no handle found for command:" + cmd) + return gerror.New("no handle found for command:" + cmd) } } else { - return errors.New("no command found") + return gerror.New("no command found") } return nil } diff --git a/os/gcron/gcron_cron.go b/os/gcron/gcron_cron.go index 28e2b8522..0e148b970 100644 --- a/os/gcron/gcron_cron.go +++ b/os/gcron/gcron_cron.go @@ -7,8 +7,7 @@ package gcron import ( - "errors" - "fmt" + "github.com/gogf/gf/errors/gerror" "time" "github.com/gogf/gf/container/garray" @@ -63,7 +62,7 @@ func (c *Cron) GetLogLevel() int { func (c *Cron) Add(pattern string, job func(), name ...string) (*Entry, error) { if len(name) > 0 { if c.Search(name[0]) != nil { - return nil, errors.New(fmt.Sprintf(`cron job "%s" already exists`, name[0])) + return nil, gerror.Newf(`cron job "%s" already exists`, name[0]) } } return c.addEntry(pattern, job, false, name...) diff --git a/os/gcron/gcron_schedule.go b/os/gcron/gcron_schedule.go index 017875a1e..a960843e3 100644 --- a/os/gcron/gcron_schedule.go +++ b/os/gcron/gcron_schedule.go @@ -7,8 +7,7 @@ package gcron import ( - "errors" - "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" "strconv" "strings" @@ -91,7 +90,7 @@ func newSchedule(pattern string) (*cronSchedule, error) { }, nil } } else { - return nil, errors.New(fmt.Sprintf(`invalid pattern: "%s"`, pattern)) + return nil, gerror.Newf(`invalid pattern: "%s"`, pattern) } } // Handle the common cron pattern, like: @@ -140,7 +139,7 @@ func newSchedule(pattern string) (*cronSchedule, error) { } return schedule, nil } else { - return nil, errors.New(fmt.Sprintf(`invalid pattern: "%s"`, pattern)) + return nil, gerror.Newf(`invalid pattern: "%s"`, pattern) } } @@ -157,7 +156,7 @@ func parseItem(item string, min int, max int, allowQuestionMark bool) (map[int]s intervalArray := strings.Split(item, "/") if len(intervalArray) == 2 { if i, err := strconv.Atoi(intervalArray[1]); err != nil { - return nil, errors.New(fmt.Sprintf(`invalid pattern item: "%s"`, item)) + return nil, gerror.Newf(`invalid pattern item: "%s"`, item) } else { interval = i } @@ -179,7 +178,7 @@ func parseItem(item string, min int, max int, allowQuestionMark bool) (map[int]s // Eg: */5 if rangeArray[0] != "*" { if i, err := parseItemValue(rangeArray[0], fieldType); err != nil { - return nil, errors.New(fmt.Sprintf(`invalid pattern item: "%s"`, item)) + return nil, gerror.Newf(`invalid pattern item: "%s"`, item) } else { rangeMin = i rangeMax = i @@ -187,7 +186,7 @@ func parseItem(item string, min int, max int, allowQuestionMark bool) (map[int]s } if len(rangeArray) == 2 { if i, err := parseItemValue(rangeArray[1], fieldType); err != nil { - return nil, errors.New(fmt.Sprintf(`invalid pattern item: "%s"`, item)) + return nil, gerror.Newf(`invalid pattern item: "%s"`, item) } else { rangeMax = i } @@ -221,7 +220,7 @@ func parseItemValue(value string, fieldType byte) (int, error) { } } } - return 0, errors.New(fmt.Sprintf(`invalid pattern value: "%s"`, value)) + return 0, gerror.Newf(`invalid pattern value: "%s"`, value) } 
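gcron now reports a duplicated job name through gerror.Newf. A minimal sketch of named jobs; gcron.New and the six-field pattern are existing package features, and the job name is illustrative:

```go
package main

import (
	"fmt"
	"time"

	"github.com/gogf/gf/os/gcron"
)

func main() {
	c := gcron.New()

	// Six-field pattern: run every second.
	if _, err := c.Add("* * * * * *", func() { fmt.Println("tick") }, "my-job"); err != nil {
		panic(err)
	}

	// Adding the same name again returns the gerror added above:
	//   cron job "my-job" already exists
	if _, err := c.Add("*/5 * * * * *", func() {}, "my-job"); err != nil {
		fmt.Println(err)
	}

	time.Sleep(3 * time.Second)
}
```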
// meet checks if the given time <t> meets the runnable point for the job. diff --git a/os/gfile/gfile_cache.go b/os/gfile/gfile_cache.go index 758d85495..e1354f19f 100644 --- a/os/gfile/gfile_cache.go +++ b/os/gfile/gfile_cache.go @@ -14,19 +14,19 @@ import ( ) const ( - // Default expire time for file content caching in seconds. - gDEFAULT_CACHE_EXPIRE = time.Minute + defaultCacheExpire = time.Minute // defaultCacheExpire is the expire time for file content caching in seconds. + commandEnvKeyForCache = "gf.gfile.cache" // commandEnvKeyForCache is the configuration key for command argument or environment configuring cache expire duration. ) var ( // Default expire time for file content caching. - cacheExpire = gcmd.GetOptWithEnv("gf.gfile.cache", gDEFAULT_CACHE_EXPIRE).Duration() + cacheExpire = gcmd.GetOptWithEnv(commandEnvKeyForCache, defaultCacheExpire).Duration() // internalCache is the memory cache for internal usage. internalCache = gcache.New() ) -// GetContents returns string content of given file by <path> from cache. +// GetContentsWithCache returns string content of given file by <path> from cache. // If there's no content in the cache, it will read it from disk file specified by <path>. // The parameter <expire> specifies the caching time for this file content in seconds. func GetContentsWithCache(path string, duration ...time.Duration) string { @@ -62,5 +62,5 @@ func GetBytesWithCache(path string, duration ...time.Duration) []byte { // cacheKey produces the cache key for gcache. func cacheKey(path string) string { - return "gf.gfile.cache:" + path + return commandEnvKeyForCache + path } diff --git a/os/gfsnotify/gfsnotify.go b/os/gfsnotify/gfsnotify.go index 619454189..0224b8ba0 100644 --- a/os/gfsnotify/gfsnotify.go +++ b/os/gfsnotify/gfsnotify.go @@ -8,9 +8,9 @@ package gfsnotify import ( - "errors" - "fmt" + "context" "github.com/gogf/gf/container/gset" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "sync" "time" @@ -88,7 +88,7 @@ func New() (*Watcher, error) { if watcher, err := fsnotify.NewWatcher(); err == nil { w.watcher = watcher } else { - intlog.Printf("New watcher failed: %v", err) + intlog.Printf(context.TODO(), "New watcher failed: %v", err) return nil, err } w.watchLoop() @@ -139,7 +139,7 @@ func RemoveCallback(callbackId int) error { callback = r.(*Callback) } if callback == nil { - return errors.New(fmt.Sprintf(`callback for id %d not found`, callbackId)) + return gerror.Newf(`callback for id %d not found`, callbackId) } w.RemoveCallback(callbackId) return nil diff --git a/os/gfsnotify/gfsnotify_watcher.go b/os/gfsnotify/gfsnotify_watcher.go index 37067a988..3219c26c8 100644 --- a/os/gfsnotify/gfsnotify_watcher.go +++ b/os/gfsnotify/gfsnotify_watcher.go @@ -7,8 +7,8 @@ package gfsnotify import ( - "errors" - "fmt" + "context" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/container/glist" @@ -45,9 +45,9 @@ func (w *Watcher) AddOnce(name, path string, callbackFunc func(event *Event), re for _, subPath := range fileAllDirs(path) { if fileIsDir(subPath) { if err := w.watcher.Add(subPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } else { - intlog.Printf("watcher adds monitor for: %s", subPath) + intlog.Printf(context.TODO(), "watcher adds monitor for: %s", subPath) } } } @@ -65,7 +65,7 @@ func (w *Watcher) AddOnce(name, path string, callbackFunc func(event *Event), re func (w *Watcher) addWithCallbackFunc(name, path string, callbackFunc func(event *Event), 
recursive ...bool) (callback *Callback, err error) { // Check and convert the given path to absolute path. if t := fileRealPath(path); t == "" { - return nil, errors.New(fmt.Sprintf(`"%s" does not exist`, path)) + return nil, gerror.Newf(`"%s" does not exist`, path) } else { path = t } @@ -93,9 +93,9 @@ func (w *Watcher) addWithCallbackFunc(name, path string, callbackFunc func(event }) // Add the path to underlying monitor. if err := w.watcher.Add(path); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } else { - intlog.Printf("watcher adds monitor for: %s", path) + intlog.Printf(context.TODO(), "watcher adds monitor for: %s", path) } // Add the callback to global callback map. callbackIdMap.Set(callback.Id, callback) @@ -108,7 +108,7 @@ func (w *Watcher) addWithCallbackFunc(name, path string, callbackFunc func(event func (w *Watcher) Close() { w.events.Close() if err := w.watcher.Close(); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } close(w.closeChan) } @@ -131,7 +131,7 @@ func (w *Watcher) Remove(path string) error { for _, subPath := range subPaths { if w.checkPathCanBeRemoved(subPath) { if err := w.watcher.Remove(subPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } } diff --git a/os/gfsnotify/gfsnotify_watcher_loop.go b/os/gfsnotify/gfsnotify_watcher_loop.go index 531225051..2618239ce 100644 --- a/os/gfsnotify/gfsnotify_watcher_loop.go +++ b/os/gfsnotify/gfsnotify_watcher_loop.go @@ -7,6 +7,7 @@ package gfsnotify import ( + "context" "github.com/gogf/gf/container/glist" "github.com/gogf/gf/internal/intlog" ) @@ -34,7 +35,7 @@ func (w *Watcher) watchLoop() { }, repeatEventFilterDuration) case err := <-w.watcher.Errors: - intlog.Error(err) + intlog.Error(context.TODO(), err) } } }() @@ -60,9 +61,9 @@ func (w *Watcher) eventLoop() { // It adds the path back to monitor. // We need no worry about the repeat adding. if err := w.watcher.Add(event.Path); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } else { - intlog.Printf("fake remove event, watcher re-adds monitor for: %s", event.Path) + intlog.Printf(context.TODO(), "fake remove event, watcher re-adds monitor for: %s", event.Path) } // Change the event to RENAME, which means it renames itself to its origin name. event.Op = RENAME @@ -76,9 +77,9 @@ func (w *Watcher) eventLoop() { // It might lost the monitoring for the path, so we add the path back to monitor. // We need no worry about the repeat adding. if err := w.watcher.Add(event.Path); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } else { - intlog.Printf("fake rename event, watcher re-adds monitor for: %s", event.Path) + intlog.Printf(context.TODO(), "fake rename event, watcher re-adds monitor for: %s", event.Path) } // Change the event to CHMOD. event.Op = CHMOD @@ -94,18 +95,18 @@ func (w *Watcher) eventLoop() { for _, subPath := range fileAllDirs(event.Path) { if fileIsDir(subPath) { if err := w.watcher.Add(subPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } else { - intlog.Printf("folder creation event, watcher adds monitor for: %s", subPath) + intlog.Printf(context.TODO(), "folder creation event, watcher adds monitor for: %s", subPath) } } } } else { // If it's a file, it directly adds it to monitor. 
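gfsnotify's error paths also move to gerror and its internal logging now carries a context. A sketch of adding and removing a watch; the watched path is a placeholder and the package-level Add helper (a thin wrapper over the Watcher methods above) is assumed:

```go
package main

import (
	"fmt"
	"time"

	"github.com/gogf/gf/os/gfsnotify"
)

func main() {
	// Watch a directory recursively (placeholder path).
	callback, err := gfsnotify.Add("/tmp/watch-me", func(event *gfsnotify.Event) {
		fmt.Println("event:", event.Path, event.Op)
	}, true)
	if err != nil {
		// A non-existent path yields the `"%s" does not exist` gerror from above.
		panic(err)
	}

	time.Sleep(10 * time.Second)

	// RemoveCallback returns a gerror when the id is unknown.
	if err := gfsnotify.RemoveCallback(callback.Id); err != nil {
		fmt.Println(err)
	}
}
```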
if err := w.watcher.Add(event.Path); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } else { - intlog.Printf("file creation event, watcher adds monitor for: %s", event.Path) + intlog.Printf(context.TODO(), "file creation event, watcher adds monitor for: %s", event.Path) } } diff --git a/os/glog/glog.go b/os/glog/glog.go index a465a5b43..afa420db8 100644 --- a/os/glog/glog.go +++ b/os/glog/glog.go @@ -12,12 +12,16 @@ import ( "github.com/gogf/gf/os/grpool" ) +const ( + commandEnvKeyForDebug = "gf.glog.debug" +) + var ( // Default logger object, for package method usage. logger = New() // Goroutine pool for async logging output. - // It uses only one asynchronize worker to ensure log sequence. + // It uses only one asynchronous worker to ensure log sequence. asyncPool = grpool.New(1) // defaultDebug enables debug level or not in default, @@ -26,7 +30,7 @@ var ( ) func init() { - defaultDebug = gcmd.GetOptWithEnv("gf.glog.debug", true).Bool() + defaultDebug = gcmd.GetOptWithEnv(commandEnvKeyForDebug, true).Bool() SetDebug(defaultDebug) } diff --git a/os/glog/glog_logger.go b/os/glog/glog_logger.go index 9bb943fc9..8e5f11ed3 100644 --- a/os/glog/glog_logger.go +++ b/os/glog/glog_logger.go @@ -38,12 +38,13 @@ type Logger struct { } const ( - defaultFileFormat = `{Y-m-d}.log` - defaultFileFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND - defaultFilePerm = os.FileMode(0666) - defaultFileExpire = time.Minute - pathFilterKey = "/os/glog/glog" - mustWithColor = true + defaultFileFormat = `{Y-m-d}.log` + defaultFileFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND + defaultFilePerm = os.FileMode(0666) + defaultFileExpire = time.Minute + pathFilterKey = "/os/glog/glog" + memoryLockPrefixForPrintingToFile = "glog.printToFile:" + mustWithColor = true ) const ( @@ -63,7 +64,6 @@ func New() *Logger { init: gtype.NewBool(), config: DefaultConfig(), } - logger.config.Handlers = []Handler{defaultHandler} return logger } @@ -104,11 +104,11 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) { if p.parent != nil { p = p.parent } - if !p.init.Val() && p.init.Cas(false, true) { - // It just initializes once for each logger. - if p.config.RotateSize > 0 || p.config.RotateExpire > 0 { + // It just initializes once for each logger. + if p.config.RotateSize > 0 || p.config.RotateExpire > 0 { + if !p.init.Val() && p.init.Cas(false, true) { gtimer.AddOnce(p.config.RotateCheckInterval, p.rotateChecksTimely) - intlog.Printf("logger rotation initialized: every %s", p.config.RotateCheckInterval.String()) + intlog.Printf(ctx, "logger rotation initialized: every %s", p.config.RotateCheckInterval.String()) } } @@ -170,7 +170,7 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) { // Tracing values. spanCtx := trace.SpanContextFromContext(ctx) if traceId := spanCtx.TraceID(); traceId.IsValid() { - input.CtxStr = "{TraceID:" + traceId.String() + "}" + input.CtxStr = "{" + traceId.String() + "}" } // Context values. if len(l.config.CtxKeys) > 0 { @@ -212,7 +212,7 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) { input.Next() }) if err != nil { - intlog.Error(err) + intlog.Error(ctx, err) } } else { input.Next() @@ -224,7 +224,7 @@ func (l *Logger) printToWriter(ctx context.Context, input *HandlerInput) { if l.config.Writer == nil { // Output content to disk file. if l.config.Path != "" { - l.printToFile(input.Time, input.Buffer()) + l.printToFile(ctx, input.Time, input.Buffer()) } // Allow output to stdout? 
if l.config.StdoutPrint { @@ -233,16 +233,16 @@ func (l *Logger) printToWriter(ctx context.Context, input *HandlerInput) { } else { if _, err := l.config.Writer.Write(input.Buffer().Bytes()); err != nil { // panic(err) - intlog.Error(err) + intlog.Error(ctx, err) } } } // printToFile outputs logging content to disk file. -func (l *Logger) printToFile(t time.Time, buffer *bytes.Buffer) { +func (l *Logger) printToFile(ctx context.Context, t time.Time, buffer *bytes.Buffer) { var ( logFilePath = l.getFilePath(t) - memoryLockKey = "glog.printToFile:" + logFilePath + memoryLockKey = memoryLockPrefixForPrintingToFile + logFilePath ) gmlock.Lock(memoryLockKey) defer gmlock.Unlock(memoryLockKey) @@ -254,20 +254,20 @@ func (l *Logger) printToFile(t time.Time, buffer *bytes.Buffer) { } } // Logging content outputting to disk file. - if file := l.getFilePointer(logFilePath); file == nil { - intlog.Errorf(`got nil file pointer for: %s`, logFilePath) + if file := l.getFilePointer(ctx, logFilePath); file == nil { + intlog.Errorf(ctx, `got nil file pointer for: %s`, logFilePath) } else { if _, err := file.Write(buffer.Bytes()); err != nil { - intlog.Error(err) + intlog.Error(ctx, err) } if err := file.Close(); err != nil { - intlog.Error(err) + intlog.Error(ctx, err) } } } // getFilePointer retrieves and returns a file pointer from file pool. -func (l *Logger) getFilePointer(path string) *gfpool.File { +func (l *Logger) getFilePointer(ctx context.Context, path string) *gfpool.File { file, err := gfpool.Open( path, defaultFileFlags, @@ -276,7 +276,7 @@ func (l *Logger) getFilePointer(path string) *gfpool.File { ) if err != nil { // panic(err) - intlog.Error(err) + intlog.Error(ctx, err) } return file } diff --git a/os/glog/glog_logger_chaining.go b/os/glog/glog_logger_chaining.go index f4ed24201..bb9529023 100644 --- a/os/glog/glog_logger_chaining.go +++ b/os/glog/glog_logger_chaining.go @@ -61,7 +61,7 @@ func (l *Logger) Path(path string) *Logger { if path != "" { if err := logger.SetPath(path); err != nil { // panic(err) - intlog.Error(err) + intlog.Error(l.getCtx(), err) } } return logger @@ -80,7 +80,7 @@ func (l *Logger) Cat(category string) *Logger { if logger.config.Path != "" { if err := logger.SetPath(gfile.Join(logger.config.Path, category)); err != nil { // panic(err) - intlog.Error(err) + intlog.Error(l.getCtx(), err) } } return logger @@ -123,7 +123,7 @@ func (l *Logger) LevelStr(levelStr string) *Logger { } if err := logger.SetLevelStr(levelStr); err != nil { // panic(err) - intlog.Error(err) + intlog.Error(l.getCtx(), err) } return logger } diff --git a/os/glog/glog_logger_config.go b/os/glog/glog_logger_config.go index 620351fd5..d16564662 100644 --- a/os/glog/glog_logger_config.go +++ b/os/glog/glog_logger_config.go @@ -7,8 +7,6 @@ package glog import ( - "errors" - "fmt" "github.com/fatih/color" "io" "strings" @@ -74,18 +72,18 @@ func (l *Logger) SetConfig(config Config) error { // Necessary validation. if config.Path != "" { if err := l.SetPath(config.Path); err != nil { - intlog.Error(err) + intlog.Error(l.ctx, err) return err } } - intlog.Printf("SetConfig: %+v", l.config) + intlog.Printf(l.ctx, "SetConfig: %+v", l.config) return nil } // SetConfigWithMap set configurations with map for the logger. func (l *Logger) SetConfigWithMap(m map[string]interface{}) error { if m == nil || len(m) == 0 { - return errors.New("configuration cannot be empty") + return gerror.New("configuration cannot be empty") } // The m now is a shallow copy of m. // A little tricky, isn't it? 
@@ -96,7 +94,7 @@ func (l *Logger) SetConfigWithMap(m map[string]interface{}) error { if level, ok := levelStringMap[strings.ToUpper(gconv.String(levelValue))]; ok { m[levelKey] = level } else { - return errors.New(fmt.Sprintf(`invalid level string: %v`, levelValue)) + return gerror.Newf(`invalid level string: %v`, levelValue) } } // Change string configuration to int value for file rotation size. @@ -104,11 +102,10 @@ func (l *Logger) SetConfigWithMap(m map[string]interface{}) error { if rotateSizeValue != nil { m[rotateSizeKey] = gfile.StrToSize(gconv.String(rotateSizeValue)) if m[rotateSizeKey] == -1 { - return errors.New(fmt.Sprintf(`invalid rotate size: %v`, rotateSizeValue)) + return gerror.Newf(`invalid rotate size: %v`, rotateSizeValue) } } - err := gconv.Struct(m, &l.config) - if err != nil { + if err := gconv.Struct(m, &l.config); err != nil { return err } return l.SetConfig(l.config) @@ -210,7 +207,7 @@ func (l *Logger) GetWriter() io.Writer { // SetPath sets the directory path for file logging. func (l *Logger) SetPath(path string) error { if path == "" { - return errors.New("logging path is empty") + return gerror.New("logging path is empty") } if !gfile.Exists(path) { if err := gfile.Mkdir(path); err != nil { @@ -253,5 +250,5 @@ func (l *Logger) SetPrefix(prefix string) { // SetHandlers sets the logging handlers for current logger. func (l *Logger) SetHandlers(handlers ...Handler) { - l.config.Handlers = append(handlers, defaultHandler) + l.config.Handlers = handlers } diff --git a/os/glog/glog_logger_handler.go b/os/glog/glog_logger_handler.go index b632e46bb..53e090ca8 100644 --- a/os/glog/glog_logger_handler.go +++ b/os/glog/glog_logger_handler.go @@ -106,5 +106,8 @@ func (i *HandlerInput) Next() { if len(i.logger.config.Handlers)-1 > i.index { i.index++ i.logger.config.Handlers[i.index](i.Ctx, i) + } else { + // The last handler is the default handler. + defaultHandler(i.Ctx, i) } } diff --git a/os/glog/glog_logger_level.go b/os/glog/glog_logger_level.go index 4cdad5624..f726bd7bb 100644 --- a/os/glog/glog_logger_level.go +++ b/os/glog/glog_logger_level.go @@ -7,9 +7,8 @@ package glog import ( - "errors" - "fmt" "github.com/fatih/color" + "github.com/gogf/gf/errors/gerror" "strings" ) @@ -86,8 +85,10 @@ var levelStringMap = map[string]int{ } // SetLevel sets the logging level. +// Note that levels ` LEVEL_CRIT | LEVEL_PANI | LEVEL_FATA ` cannot be removed for logging content, +// which are automatically added to levels. func (l *Logger) SetLevel(level int) { - l.config.Level = level + l.config.Level = level | LEVEL_CRIT | LEVEL_PANI | LEVEL_FATA } // GetLevel returns the logging level value. @@ -100,7 +101,7 @@ func (l *Logger) SetLevelStr(levelStr string) error { if level, ok := levelStringMap[strings.ToUpper(levelStr)]; ok { l.config.Level = level } else { - return errors.New(fmt.Sprintf(`invalid level string: %s`, levelStr)) + return gerror.Newf(`invalid level string: %s`, levelStr) } return nil } diff --git a/os/glog/glog_logger_rotate.go b/os/glog/glog_logger_rotate.go index e508b8006..697793854 100644 --- a/os/glog/glog_logger_rotate.go +++ b/os/glog/glog_logger_rotate.go @@ -19,6 +19,10 @@ import ( "time" ) +const ( + memoryLockPrefixForRotating = "glog.rotateChecksTimely:" +) + // rotateFileBySize rotates the current logging file according to the // configured rotation size. 
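SetHandlers no longer appends the built-in handler itself; instead HandlerInput.Next falls through to the default handler after the last custom one. A sketch of a middleware-style handler relying only on the Handler signature and Next from the diff:

```go
package main

import (
	"context"
	"fmt"

	"github.com/gogf/gf/os/glog"
)

// wrapHandler runs around the default output handler. It only needs to call
// in.Next(); the default handler is now invoked automatically after the last
// custom handler in the chain.
func wrapHandler(ctx context.Context, in *glog.HandlerInput) {
	fmt.Println("before default handler")
	in.Next()
	fmt.Println("after default handler")
}

func main() {
	l := glog.New()
	l.SetHandlers(wrapHandler) // no need to append glog's own handler anymore
	l.Print("hello handlers")
}
```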
func (l *Logger) rotateFileBySize(now time.Time) { @@ -27,7 +31,7 @@ func (l *Logger) rotateFileBySize(now time.Time) { } if err := l.doRotateFile(l.getFilePath(now)); err != nil { // panic(err) - intlog.Error(err) + intlog.Error(l.ctx, err) } } @@ -44,7 +48,11 @@ func (l *Logger) doRotateFile(filePath string) error { if err := gfile.Remove(filePath); err != nil { return err } - intlog.Printf(`%d size exceeds, no backups set, remove original logging file: %s`, l.config.RotateSize, filePath) + intlog.Printf( + l.ctx, + `%d size exceeds, no backups set, remove original logging file: %s`, + l.config.RotateSize, filePath, + ) return nil } // Else it creates new backup files. @@ -79,7 +87,7 @@ func (l *Logger) doRotateFile(filePath string) error { if !gfile.Exists(newFilePath) { break } else { - intlog.Printf(`rotation file exists, continue: %s`, newFilePath) + intlog.Printf(l.ctx, `rotation file exists, continue: %s`, newFilePath) } } if err := gfile.Rename(filePath, newFilePath); err != nil { @@ -91,9 +99,11 @@ func (l *Logger) doRotateFile(filePath string) error { // rotateChecksTimely timely checks the backups expiration and the compression. func (l *Logger) rotateChecksTimely() { defer gtimer.AddOnce(l.config.RotateCheckInterval, l.rotateChecksTimely) + // Checks whether file rotation not enabled. if l.config.RotateSize <= 0 && l.config.RotateExpire == 0 { intlog.Printf( + l.ctx, "logging rotation ignore checks: RotateSize: %d, RotateExpire: %s", l.config.RotateSize, l.config.RotateExpire.String(), ) @@ -101,18 +111,21 @@ func (l *Logger) rotateChecksTimely() { } // It here uses memory lock to guarantee the concurrent safety. - memoryLockKey := "glog.rotateChecksTimely:" + l.config.Path + memoryLockKey := memoryLockPrefixForRotating + l.config.Path if !gmlock.TryLock(memoryLockKey) { return } defer gmlock.Unlock(memoryLockKey) var ( - now = time.Now() - pattern = "*.log, *.gz" - files, _ = gfile.ScanDirFile(l.config.Path, pattern, true) + now = time.Now() + pattern = "*.log, *.gz" + files, err = gfile.ScanDirFile(l.config.Path, pattern, true) ) - intlog.Printf("logging rotation start checks: %+v", files) + if err != nil { + intlog.Error(l.ctx, err) + } + intlog.Printf(l.ctx, "logging rotation start checks: %+v", files) // ============================================================= // Rotation of expired file checks. // ============================================================= @@ -131,17 +144,21 @@ func (l *Logger) rotateChecksTimely() { if subDuration > l.config.RotateExpire { expireRotated = true intlog.Printf( + l.ctx, `%v - %v = %v > %v, rotation expire logging file: %s`, now, mtime, subDuration, l.config.RotateExpire, file, ) if err := l.doRotateFile(file); err != nil { - intlog.Error(err) + intlog.Error(l.ctx, err) } } } if expireRotated { // Update the files array. - files, _ = gfile.ScanDirFile(l.config.Path, pattern, true) + files, err = gfile.ScanDirFile(l.config.Path, pattern, true) + if err != nil { + intlog.Error(l.ctx, err) + } } } @@ -165,17 +182,20 @@ func (l *Logger) rotateChecksTimely() { needCompressFileArray.Iterator(func(_ int, path string) bool { err := gcompress.GzipFile(path, path+".gz") if err == nil { - intlog.Printf(`compressed done, remove original logging file: %s`, path) + intlog.Printf(l.ctx, `compressed done, remove original logging file: %s`, path) if err = gfile.Remove(path); err != nil { - intlog.Print(err) + intlog.Print(l.ctx, err) } } else { - intlog.Print(err) + intlog.Print(l.ctx, err) } return true }) // Update the files array. 
- files, _ = gfile.ScanDirFile(l.config.Path, pattern, true) + files, err = gfile.ScanDirFile(l.config.Path, pattern, true) + if err != nil { + intlog.Error(l.ctx, err) + } } } @@ -192,10 +212,12 @@ func (l *Logger) rotateChecksTimely() { if backupFilesMap[originalLoggingFilePath] == nil { backupFilesMap[originalLoggingFilePath] = garray.NewSortedArray(func(a, b interface{}) int { // Sorted by rotated/backup file mtime. - // The old rotated/backup file is put in the head of array. - file1 := a.(string) - file2 := b.(string) - result := gfile.MTimestampMilli(file1) - gfile.MTimestampMilli(file2) + // The older rotated/backup file is put in the head of array. + var ( + file1 = a.(string) + file2 = b.(string) + result = gfile.MTimestampMilli(file1) - gfile.MTimestampMilli(file2) + ) if result <= 0 { return -1 } @@ -207,18 +229,18 @@ func (l *Logger) rotateChecksTimely() { backupFilesMap[originalLoggingFilePath].Add(file) } } - intlog.Printf(`calculated backup files map: %+v`, backupFilesMap) + intlog.Printf(l.ctx, `calculated backup files map: %+v`, backupFilesMap) for _, array := range backupFilesMap { diff := array.Len() - l.config.RotateBackupLimit for i := 0; i < diff; i++ { path, _ := array.PopLeft() - intlog.Printf(`remove exceeded backup limit file: %s`, path) + intlog.Printf(l.ctx, `remove exceeded backup limit file: %s`, path) if err := gfile.Remove(path.(string)); err != nil { - intlog.Print(err) + intlog.Error(l.ctx, err) } } } - // Backup expiration checks. + // Backups expiration checking. if l.config.RotateBackupExpire > 0 { var ( mtime time.Time @@ -231,11 +253,12 @@ func (l *Logger) rotateChecksTimely() { subDuration = now.Sub(mtime) if subDuration > l.config.RotateBackupExpire { intlog.Printf( + l.ctx, `%v - %v = %v > %v, remove expired backup file: %s`, now, mtime, subDuration, l.config.RotateBackupExpire, path, ) if err := gfile.Remove(path); err != nil { - intlog.Print(err) + intlog.Error(l.ctx, err) } return true } else { diff --git a/os/gproc/gproc_comm.go b/os/gproc/gproc_comm.go index ca9911c83..fc013eec9 100644 --- a/os/gproc/gproc_comm.go +++ b/os/gproc/gproc_comm.go @@ -7,9 +7,9 @@ package gproc import ( - "errors" "fmt" "github.com/gogf/gf/container/gmap" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/net/gtcp" "github.com/gogf/gf/os/gfile" "github.com/gogf/gf/util/gconv" @@ -65,7 +65,7 @@ func getConnByPid(pid int) (*gtcp.PoolConn, error) { return nil, err } } - return nil, errors.New(fmt.Sprintf("could not find port for pid: %d", pid)) + return nil, gerror.Newf("could not find port for pid: %d", pid) } // getPortByPid returns the listening port for specified pid. 
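Two behavioural changes earlier in this patch are easy to miss when reading the hunks: SetHandlers no longer appends the default handler (HandlerInput.Next now falls through to it after the last custom handler), and SetLevel always keeps LEVEL_CRIT, LEVEL_PANI and LEVEL_FATA enabled. A hedged sketch of what that means for callers; the Handler signature func(ctx context.Context, in *glog.HandlerInput) is inferred from the Next() hunk, and g.Log() from frame/g is assumed:

package main

import (
	"context"
	"fmt"

	"github.com/gogf/gf/frame/g"
	"github.com/gogf/gf/os/glog"
)

func main() {
	// Assumed Handler signature, inferred from the Next() change above.
	g.Log().SetHandlers(func(ctx context.Context, in *glog.HandlerInput) {
		// Custom pre-processing could happen here (metrics, redaction, forwarding).
		fmt.Println("custom handler invoked")
		// Next() ends the chain with the built-in default handler, so the log
		// content is still written as before even though it was not registered.
		in.Next()
	})
	// LEVEL_CRIT | LEVEL_PANI | LEVEL_FATA are forced on by SetLevel, so a narrow
	// mask like LEVEL_INFO no longer silences fatal/critical output.
	g.Log().SetLevel(glog.LEVEL_INFO)
	g.Log().Info("info is enabled by the mask")
}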
diff --git a/os/gproc/gproc_comm_send.go b/os/gproc/gproc_comm_send.go index b53c5132e..7946c9a81 100644 --- a/os/gproc/gproc_comm_send.go +++ b/os/gproc/gproc_comm_send.go @@ -7,7 +7,7 @@ package gproc import ( - "errors" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/net/gtcp" "io" @@ -46,7 +46,7 @@ func Send(pid int, data []byte, group ...string) error { err = json.UnmarshalUseNumber(result, response) if err == nil { if response.Code != 1 { - err = errors.New(response.Message) + err = gerror.New(response.Message) } } } diff --git a/os/gproc/gproc_process.go b/os/gproc/gproc_process.go index 74bf04344..51fa28d70 100644 --- a/os/gproc/gproc_process.go +++ b/os/gproc/gproc_process.go @@ -7,8 +7,9 @@ package gproc import ( - "errors" + "context" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "os" "os/exec" @@ -100,7 +101,7 @@ func (p *Process) Send(data []byte) error { if p.Process != nil { return Send(p.Process.Pid, data) } - return errors.New("invalid process") + return gerror.New("invalid process") } // Release releases any resources associated with the Process p, @@ -118,12 +119,11 @@ func (p *Process) Kill() error { } if runtime.GOOS != "windows" { if err = p.Process.Release(); err != nil { - intlog.Error(err) - //return err + intlog.Error(context.TODO(), err) } } _, err = p.Process.Wait() - intlog.Error(err) + intlog.Error(context.TODO(), err) //return err return nil } else { diff --git a/os/gproc/gproc_signal.go b/os/gproc/gproc_signal.go index 08d059e00..1f742289f 100644 --- a/os/gproc/gproc_signal.go +++ b/os/gproc/gproc_signal.go @@ -7,6 +7,7 @@ package gproc import ( + "context" "github.com/gogf/gf/internal/intlog" "os" "os/signal" @@ -67,7 +68,7 @@ func Listen() { for { wg := sync.WaitGroup{} sig = <-sigChan - intlog.Printf(`signal received: %s`, sig.String()) + intlog.Printf(context.TODO(), `signal received: %s`, sig.String()) if handlers, ok := signalHandlerMap[sig]; ok { for _, handler := range handlers { wg.Add(1) diff --git a/os/gres/gres_func_zip.go b/os/gres/gres_func_zip.go index bc6e0ac50..706bd02bc 100644 --- a/os/gres/gres_func_zip.go +++ b/os/gres/gres_func_zip.go @@ -8,6 +8,7 @@ package gres import ( "archive/zip" + "context" "github.com/gogf/gf/internal/fileinfo" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" @@ -70,7 +71,7 @@ func doZipPathWriter(path string, exclude string, zipWriter *zip.Writer, prefix headerPrefix = strings.Replace(headerPrefix, "//", "/", -1) for _, file := range files { if exclude == file { - intlog.Printf(`exclude file path: %s`, file) + intlog.Printf(context.TODO(), `exclude file path: %s`, file) continue } err = zipFile(file, headerPrefix+gfile.Dir(file[len(path):]), zipWriter) diff --git a/os/gres/gres_resource.go b/os/gres/gres_resource.go index 2ed7b2e46..cb9213b00 100644 --- a/os/gres/gres_resource.go +++ b/os/gres/gres_resource.go @@ -7,6 +7,7 @@ package gres import ( + "context" "fmt" "github.com/gogf/gf/internal/intlog" "os" @@ -42,7 +43,7 @@ func New() *Resource { func (r *Resource) Add(content string, prefix ...string) error { files, err := UnpackContent(content) if err != nil { - intlog.Printf("Add resource files failed: %v", err) + intlog.Printf(context.TODO(), "Add resource files failed: %v", err) return err } namePrefix := "" @@ -53,7 +54,7 @@ func (r *Resource) Add(content string, prefix ...string) error { files[i].resource = r r.tree.Set(namePrefix+files[i].file.Name, files[i]) } - intlog.Printf("Add %d files to resource 
manager", r.tree.Size()) + intlog.Printf(context.TODO(), "Add %d files to resource manager", r.tree.Size()) return nil } diff --git a/os/grpool/grpool.go b/os/grpool/grpool.go index 136484367..707eb1304 100644 --- a/os/grpool/grpool.go +++ b/os/grpool/grpool.go @@ -8,8 +8,7 @@ package grpool import ( - "errors" - "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/container/glist" "github.com/gogf/gf/container/gtype" @@ -70,7 +69,7 @@ func Jobs() int { // The job will be executed asynchronously. func (p *Pool) Add(f func()) error { for p.closed.Val() { - return errors.New("pool closed") + return gerror.New("pool closed") } p.list.PushFront(f) // Check whether fork new goroutine or not. @@ -99,7 +98,7 @@ func (p *Pool) AddWithRecover(userFunc func(), recoverFunc ...func(err error)) e defer func() { if err := recover(); err != nil { if len(recoverFunc) > 0 && recoverFunc[0] != nil { - recoverFunc[0](errors.New(fmt.Sprintf(`%v`, err))) + recoverFunc[0](gerror.Newf(`%v`, err)) } } }() diff --git a/os/gsession/gsession_manager.go b/os/gsession/gsession_manager.go index c52095146..7c69c6fe4 100644 --- a/os/gsession/gsession_manager.go +++ b/os/gsession/gsession_manager.go @@ -7,6 +7,7 @@ package gsession import ( + "context" "github.com/gogf/gf/container/gmap" "time" @@ -41,13 +42,14 @@ func New(ttl time.Duration, storage ...Storage) *Manager { // New creates or fetches the session for given session id. // The parameter <sessionId> is optional, it creates a new one if not it's passed // depending on Storage.New. -func (m *Manager) New(sessionId ...string) *Session { +func (m *Manager) New(ctx context.Context, sessionId ...string) *Session { var id string if len(sessionId) > 0 && sessionId[0] != "" { id = sessionId[0] } return &Session{ id: id, + ctx: ctx, manager: m, } } diff --git a/os/gsession/gsession_session.go b/os/gsession/gsession_session.go index 17e764852..2067d5b35 100644 --- a/os/gsession/gsession_session.go +++ b/os/gsession/gsession_session.go @@ -7,7 +7,8 @@ package gsession import ( - "errors" + "context" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "time" @@ -17,10 +18,12 @@ import ( "github.com/gogf/gf/util/gconv" ) -// Session struct for storing single session data, -// which is bound to a single request. +// Session struct for storing single session data, which is bound to a single request. +// The Session struct is the interface with user, but the Storage is the underlying adapter designed interface +// for functionality implements. type Session struct { id string // Session id. + ctx context.Context // Context for current session, note that: one session one context. data *gmap.StrAnyMap // Session data. dirty bool // Used to mark session is modified. start bool // Used to mark session is started. @@ -42,12 +45,12 @@ func (s *Session) init() { // Retrieve memory session data from manager. if r, _ := s.manager.sessionData.Get(s.id); r != nil { s.data = r.(*gmap.StrAnyMap) - intlog.Print("session init data:", s.data) + intlog.Print(s.ctx, "session init data:", s.data) } // Retrieve stored session data from storage. 
if s.manager.storage != nil { - if s.data, err = s.manager.storage.GetSession(s.id, s.manager.ttl, s.data); err != nil { - intlog.Errorf("session restoring failed for id '%s': %v", s.id, err) + if s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.ttl, s.data); err != nil { + intlog.Errorf(s.ctx, "session restoring failed for id '%s': %v", s.id, err) } } } @@ -57,7 +60,7 @@ func (s *Session) init() { } // Use default session id creating function of storage. if s.id == "" { - s.id = s.manager.storage.New(s.manager.ttl) + s.id = s.manager.storage.New(s.ctx, s.manager.ttl) } // Use default session id creating function. if s.id == "" { @@ -78,11 +81,11 @@ func (s *Session) Close() { size := s.data.Size() if s.manager.storage != nil { if s.dirty { - if err := s.manager.storage.SetSession(s.id, s.data, s.manager.ttl); err != nil { + if err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl); err != nil { panic(err) } } else if size > 0 { - if err := s.manager.storage.UpdateTTL(s.id, s.manager.ttl); err != nil { + if err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl); err != nil { panic(err) } } @@ -96,7 +99,7 @@ func (s *Session) Close() { // Set sets key-value pair to this session. func (s *Session) Set(key string, value interface{}) error { s.init() - if err := s.manager.storage.Set(s.id, key, value, s.manager.ttl); err != nil { + if err := s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil { if err == ErrorDisabled { s.data.Set(key, value) } else { @@ -116,7 +119,7 @@ func (s *Session) Sets(data map[string]interface{}) error { // SetMap batch sets the session using map. func (s *Session) SetMap(data map[string]interface{}) error { s.init() - if err := s.manager.storage.SetMap(s.id, data, s.manager.ttl); err != nil { + if err := s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil { if err == ErrorDisabled { s.data.Sets(data) } else { @@ -134,7 +137,7 @@ func (s *Session) Remove(keys ...string) error { } s.init() for _, key := range keys { - if err := s.manager.storage.Remove(s.id, key); err != nil { + if err := s.manager.storage.Remove(s.ctx, s.id, key); err != nil { if err == ErrorDisabled { s.data.Remove(key) } else { @@ -157,7 +160,7 @@ func (s *Session) RemoveAll() error { return nil } s.init() - if err := s.manager.storage.RemoveAll(s.id); err != nil { + if err := s.manager.storage.RemoveAll(s.ctx, s.id); err != nil { if err == ErrorDisabled { s.data.Clear() } else { @@ -179,7 +182,7 @@ func (s *Session) Id() string { // It returns error if it is called after session starts. func (s *Session) SetId(id string) error { if s.start { - return errors.New("session already started") + return gerror.New("session already started") } s.id = id return nil @@ -189,7 +192,7 @@ func (s *Session) SetId(id string) error { // It returns error if it is called after session starts. 
func (s *Session) SetIdFunc(f func(ttl time.Duration) string) error { if s.start { - return errors.New("session already started") + return gerror.New("session already started") } s.idFunc = f return nil @@ -200,7 +203,7 @@ func (s *Session) SetIdFunc(f func(ttl time.Duration) string) error { func (s *Session) Map() map[string]interface{} { if s.id != "" { s.init() - if data := s.manager.storage.GetMap(s.id); data != nil { + if data := s.manager.storage.GetMap(s.ctx, s.id); data != nil { return data } return s.data.Map() @@ -212,7 +215,7 @@ func (s *Session) Map() map[string]interface{} { func (s *Session) Size() int { if s.id != "" { s.init() - if size := s.manager.storage.GetSize(s.id); size >= 0 { + if size := s.manager.storage.GetSize(s.ctx, s.id); size >= 0 { return size } return s.data.Size() @@ -239,7 +242,7 @@ func (s *Session) Get(key string, def ...interface{}) interface{} { return nil } s.init() - if v := s.manager.storage.Get(s.id, key); v != nil { + if v := s.manager.storage.Get(s.ctx, s.id, key); v != nil { return v } if v := s.data.Get(key); v != nil { @@ -363,16 +366,6 @@ func (s *Session) GetStruct(key string, pointer interface{}, mapping ...map[stri return gconv.Struct(s.Get(key), pointer, mapping...) } -// Deprecated, use GetStruct instead. -func (s *Session) GetStructDeep(key string, pointer interface{}, mapping ...map[string]string) error { - return gconv.StructDeep(s.Get(key), pointer, mapping...) -} - func (s *Session) GetStructs(key string, pointer interface{}, mapping ...map[string]string) error { return gconv.Structs(s.Get(key), pointer, mapping...) } - -// Deprecated, use GetStructs instead. -func (s *Session) GetStructsDeep(key string, pointer interface{}, mapping ...map[string]string) error { - return gconv.StructsDeep(s.Get(key), pointer, mapping...) -} diff --git a/os/gsession/gsession_storage.go b/os/gsession/gsession_storage.go index 02b72e322..87ef2a669 100644 --- a/os/gsession/gsession_storage.go +++ b/os/gsession/gsession_storage.go @@ -7,6 +7,7 @@ package gsession import ( + "context" "github.com/gogf/gf/container/gmap" "time" ) @@ -15,31 +16,31 @@ import ( type Storage interface { // New creates a custom session id. // This function can be used for custom session creation. - New(ttl time.Duration) (id string) + New(ctx context.Context, ttl time.Duration) (id string) // Get retrieves and returns session value with given key. // It returns nil if the key does not exist in the session. - Get(id string, key string) interface{} + Get(ctx context.Context, id string, key string) interface{} // GetMap retrieves all key-value pairs as map from storage. - GetMap(id string) map[string]interface{} + GetMap(ctx context.Context, id string) map[string]interface{} // GetSize retrieves and returns the size of key-value pairs from storage. - GetSize(id string) int + GetSize(ctx context.Context, id string) int // Set sets one key-value session pair to the storage. // The parameter <ttl> specifies the TTL for the session id. - Set(id string, key string, value interface{}, ttl time.Duration) error + Set(ctx context.Context, id string, key string, value interface{}, ttl time.Duration) error // SetMap batch sets key-value session pairs as map to the storage. // The parameter <ttl> specifies the TTL for the session id. - SetMap(id string, data map[string]interface{}, ttl time.Duration) error + SetMap(ctx context.Context, id string, data map[string]interface{}, ttl time.Duration) error // Remove deletes key with its value from storage. 
- Remove(id string, key string) error + Remove(ctx context.Context, id string, key string) error // RemoveAll deletes all key-value pairs from storage. - RemoveAll(id string) error + RemoveAll(ctx context.Context, id string) error // GetSession returns the session data as *gmap.StrAnyMap for given session id from storage. // @@ -48,14 +49,14 @@ type Storage interface { // and for some storage it might be nil if memory storage is disabled. // // This function is called ever when session starts. It returns nil if the TTL is exceeded. - GetSession(id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) + GetSession(ctx context.Context, id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) // SetSession updates the data for specified session id. // This function is called ever after session, which is changed dirty, is closed. // This copy all session data map from memory to storage. - SetSession(id string, data *gmap.StrAnyMap, ttl time.Duration) error + SetSession(ctx context.Context, id string, data *gmap.StrAnyMap, ttl time.Duration) error // UpdateTTL updates the TTL for specified session id. // This function is called ever after session, which is not dirty, is closed. - UpdateTTL(id string, ttl time.Duration) error + UpdateTTL(ctx context.Context, id string, ttl time.Duration) error } diff --git a/os/gsession/gsession_storage_file.go b/os/gsession/gsession_storage_file.go index fb0bc7ca6..7bbbaf2ef 100644 --- a/os/gsession/gsession_storage_file.go +++ b/os/gsession/gsession_storage_file.go @@ -7,6 +7,7 @@ package gsession import ( + "context" "github.com/gogf/gf/container/gmap" "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" @@ -80,8 +81,8 @@ func (s *StorageFile) updateSessionTimely() { if id = s.updatingIdSet.Pop(); id == "" { break } - if err = s.updateSessionTTl(id); err != nil { - intlog.Error(err) + if err = s.updateSessionTTl(context.TODO(), id); err != nil { + intlog.Error(context.TODO(), err) } } } @@ -104,45 +105,45 @@ func (s *StorageFile) sessionFilePath(id string) string { // New creates a session id. // This function can be used for custom session creation. -func (s *StorageFile) New(ttl time.Duration) (id string) { +func (s *StorageFile) New(ctx context.Context, ttl time.Duration) (id string) { return "" } // Get retrieves session value with given key. // It returns nil if the key does not exist in the session. -func (s *StorageFile) Get(id string, key string) interface{} { +func (s *StorageFile) Get(ctx context.Context, id string, key string) interface{} { return nil } // GetMap retrieves all key-value pairs as map from storage. -func (s *StorageFile) GetMap(id string) map[string]interface{} { +func (s *StorageFile) GetMap(ctx context.Context, id string) map[string]interface{} { return nil } // GetSize retrieves the size of key-value pairs from storage. -func (s *StorageFile) GetSize(id string) int { +func (s *StorageFile) GetSize(ctx context.Context, id string) int { return -1 } // Set sets key-value session pair to the storage. // The parameter <ttl> specifies the TTL for the session id (not for the key-value pair). -func (s *StorageFile) Set(id string, key string, value interface{}, ttl time.Duration) error { +func (s *StorageFile) Set(ctx context.Context, id string, key string, value interface{}, ttl time.Duration) error { return ErrorDisabled } // SetMap batch sets key-value session pairs with map to the storage. // The parameter <ttl> specifies the TTL for the session id(not for the key-value pair). 
-func (s *StorageFile) SetMap(id string, data map[string]interface{}, ttl time.Duration) error { +func (s *StorageFile) SetMap(ctx context.Context, id string, data map[string]interface{}, ttl time.Duration) error { return ErrorDisabled } // Remove deletes key with its value from storage. -func (s *StorageFile) Remove(id string, key string) error { +func (s *StorageFile) Remove(ctx context.Context, id string, key string) error { return ErrorDisabled } // RemoveAll deletes all key-value pairs from storage. -func (s *StorageFile) RemoveAll(id string) error { +func (s *StorageFile) RemoveAll(ctx context.Context, id string) error { return ErrorDisabled } @@ -153,11 +154,10 @@ func (s *StorageFile) RemoveAll(id string) error { // and for some storage it might be nil if memory storage is disabled. // // This function is called ever when session starts. -func (s *StorageFile) GetSession(id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { +func (s *StorageFile) GetSession(ctx context.Context, id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { if data != nil { return data, nil } - //intlog.Printf("StorageFile.GetSession: %s, %v", id, ttl) path := s.sessionFilePath(id) content := gfile.GetBytes(path) if len(content) > 8 { @@ -189,8 +189,8 @@ func (s *StorageFile) GetSession(id string, ttl time.Duration, data *gmap.StrAny // SetSession updates the data map for specified session id. // This function is called ever after session, which is changed dirty, is closed. // This copy all session data map from memory to storage. -func (s *StorageFile) SetSession(id string, data *gmap.StrAnyMap, ttl time.Duration) error { - intlog.Printf("StorageFile.SetSession: %s, %v, %v", id, data, ttl) +func (s *StorageFile) SetSession(ctx context.Context, id string, data *gmap.StrAnyMap, ttl time.Duration) error { + intlog.Printf(ctx, "StorageFile.SetSession: %s, %v, %v", id, data, ttl) path := s.sessionFilePath(id) content, err := json.Marshal(data) if err != nil { @@ -222,8 +222,8 @@ func (s *StorageFile) SetSession(id string, data *gmap.StrAnyMap, ttl time.Durat // UpdateTTL updates the TTL for specified session id. // This function is called ever after session, which is not dirty, is closed. // It just adds the session id to the async handling queue. -func (s *StorageFile) UpdateTTL(id string, ttl time.Duration) error { - intlog.Printf("StorageFile.UpdateTTL: %s, %v", id, ttl) +func (s *StorageFile) UpdateTTL(ctx context.Context, id string, ttl time.Duration) error { + intlog.Printf(ctx, "StorageFile.UpdateTTL: %s, %v", id, ttl) if ttl >= DefaultStorageFileLoopInterval { s.updatingIdSet.Add(id) } @@ -231,8 +231,8 @@ func (s *StorageFile) UpdateTTL(id string, ttl time.Duration) error { } // updateSessionTTL updates the TTL for specified session id. 
-func (s *StorageFile) updateSessionTTl(id string) error { - intlog.Printf("StorageFile.updateSession: %s", id) +func (s *StorageFile) updateSessionTTl(ctx context.Context, id string) error { + intlog.Printf(ctx, "StorageFile.updateSession: %s", id) path := s.sessionFilePath(id) file, err := gfile.OpenWithFlag(path, os.O_WRONLY) if err != nil { diff --git a/os/gsession/gsession_storage_memory.go b/os/gsession/gsession_storage_memory.go index 7b41e7ff6..26f217b92 100644 --- a/os/gsession/gsession_storage_memory.go +++ b/os/gsession/gsession_storage_memory.go @@ -7,6 +7,7 @@ package gsession import ( + "context" "github.com/gogf/gf/container/gmap" "time" ) @@ -21,45 +22,45 @@ func NewStorageMemory() *StorageMemory { // New creates a session id. // This function can be used for custom session creation. -func (s *StorageMemory) New(ttl time.Duration) (id string) { +func (s *StorageMemory) New(ctx context.Context, ttl time.Duration) (id string) { return "" } // Get retrieves session value with given key. // It returns nil if the key does not exist in the session. -func (s *StorageMemory) Get(id string, key string) interface{} { +func (s *StorageMemory) Get(ctx context.Context, id string, key string) interface{} { return nil } // GetMap retrieves all key-value pairs as map from storage. -func (s *StorageMemory) GetMap(id string) map[string]interface{} { +func (s *StorageMemory) GetMap(ctx context.Context, id string) map[string]interface{} { return nil } // GetSize retrieves the size of key-value pairs from storage. -func (s *StorageMemory) GetSize(id string) int { +func (s *StorageMemory) GetSize(ctx context.Context, id string) int { return -1 } // Set sets key-value session pair to the storage. // The parameter <ttl> specifies the TTL for the session id (not for the key-value pair). -func (s *StorageMemory) Set(id string, key string, value interface{}, ttl time.Duration) error { +func (s *StorageMemory) Set(ctx context.Context, id string, key string, value interface{}, ttl time.Duration) error { return ErrorDisabled } // SetMap batch sets key-value session pairs with map to the storage. // The parameter <ttl> specifies the TTL for the session id(not for the key-value pair). -func (s *StorageMemory) SetMap(id string, data map[string]interface{}, ttl time.Duration) error { +func (s *StorageMemory) SetMap(ctx context.Context, id string, data map[string]interface{}, ttl time.Duration) error { return ErrorDisabled } // Remove deletes key with its value from storage. -func (s *StorageMemory) Remove(id string, key string) error { +func (s *StorageMemory) Remove(ctx context.Context, id string, key string) error { return ErrorDisabled } // RemoveAll deletes all key-value pairs from storage. -func (s *StorageMemory) RemoveAll(id string) error { +func (s *StorageMemory) RemoveAll(ctx context.Context, id string) error { return ErrorDisabled } @@ -70,25 +71,20 @@ func (s *StorageMemory) RemoveAll(id string) error { // and for some storage it might be nil if memory storage is disabled. // // This function is called ever when session starts. -func (s *StorageMemory) GetSession(id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { +func (s *StorageMemory) GetSession(ctx context.Context, id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { return data, nil } // SetSession updates the data map for specified session id. // This function is called ever after session, which is changed dirty, is closed. // This copy all session data map from memory to storage. 
-func (s *StorageMemory) SetSession(id string, data *gmap.StrAnyMap, ttl time.Duration) error { +func (s *StorageMemory) SetSession(ctx context.Context, id string, data *gmap.StrAnyMap, ttl time.Duration) error { return nil } // UpdateTTL updates the TTL for specified session id. // This function is called ever after session, which is not dirty, is closed. // It just adds the session id to the async handling queue. -func (s *StorageMemory) UpdateTTL(id string, ttl time.Duration) error { - return nil -} - -// doUpdateTTL updates the TTL for session id. -func (s *StorageMemory) doUpdateTTL(id string) error { +func (s *StorageMemory) UpdateTTL(ctx context.Context, id string, ttl time.Duration) error { return nil } diff --git a/os/gsession/gsession_storage_redis.go b/os/gsession/gsession_storage_redis.go index 8d26e2db6..a8b8016af 100644 --- a/os/gsession/gsession_storage_redis.go +++ b/os/gsession/gsession_storage_redis.go @@ -7,6 +7,7 @@ package gsession import ( + "context" "github.com/gogf/gf/container/gmap" "github.com/gogf/gf/database/gredis" "github.com/gogf/gf/internal/intlog" @@ -44,7 +45,7 @@ func NewStorageRedis(redis *gredis.Redis, prefix ...string) *StorageRedis { } // Batch updates the TTL for session ids timely. gtimer.AddSingleton(DefaultStorageRedisLoopInterval, func() { - intlog.Print("StorageRedis.timer start") + intlog.Print(context.TODO(), "StorageRedis.timer start") var ( id string err error @@ -54,57 +55,57 @@ func NewStorageRedis(redis *gredis.Redis, prefix ...string) *StorageRedis { if id, ttlSeconds = s.updatingIdMap.Pop(); id == "" { break } else { - if err = s.doUpdateTTL(id, ttlSeconds); err != nil { - intlog.Error(err) + if err = s.doUpdateTTL(context.TODO(), id, ttlSeconds); err != nil { + intlog.Error(context.TODO(), err) } } } - intlog.Print("StorageRedis.timer end") + intlog.Print(context.TODO(), "StorageRedis.timer end") }) return s } // New creates a session id. // This function can be used for custom session creation. -func (s *StorageRedis) New(ttl time.Duration) (id string) { +func (s *StorageRedis) New(ctx context.Context, ttl time.Duration) (id string) { return "" } // Get retrieves session value with given key. // It returns nil if the key does not exist in the session. -func (s *StorageRedis) Get(id string, key string) interface{} { +func (s *StorageRedis) Get(ctx context.Context, id string, key string) interface{} { return nil } // GetMap retrieves all key-value pairs as map from storage. -func (s *StorageRedis) GetMap(id string) map[string]interface{} { +func (s *StorageRedis) GetMap(ctx context.Context, id string) map[string]interface{} { return nil } // GetSize retrieves the size of key-value pairs from storage. -func (s *StorageRedis) GetSize(id string) int { +func (s *StorageRedis) GetSize(ctx context.Context, id string) int { return -1 } // Set sets key-value session pair to the storage. // The parameter <ttl> specifies the TTL for the session id (not for the key-value pair). -func (s *StorageRedis) Set(id string, key string, value interface{}, ttl time.Duration) error { +func (s *StorageRedis) Set(ctx context.Context, id string, key string, value interface{}, ttl time.Duration) error { return ErrorDisabled } // SetMap batch sets key-value session pairs with map to the storage. // The parameter <ttl> specifies the TTL for the session id(not for the key-value pair). 
-func (s *StorageRedis) SetMap(id string, data map[string]interface{}, ttl time.Duration) error { +func (s *StorageRedis) SetMap(ctx context.Context, id string, data map[string]interface{}, ttl time.Duration) error { return ErrorDisabled } // Remove deletes key with its value from storage. -func (s *StorageRedis) Remove(id string, key string) error { +func (s *StorageRedis) Remove(ctx context.Context, id string, key string) error { return ErrorDisabled } // RemoveAll deletes all key-value pairs from storage. -func (s *StorageRedis) RemoveAll(id string) error { +func (s *StorageRedis) RemoveAll(ctx context.Context, id string) error { return ErrorDisabled } @@ -115,9 +116,9 @@ func (s *StorageRedis) RemoveAll(id string) error { // and for some storage it might be nil if memory storage is disabled. // // This function is called ever when session starts. -func (s *StorageRedis) GetSession(id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { - intlog.Printf("StorageRedis.GetSession: %s, %v", id, ttl) - r, err := s.redis.DoVar("GET", s.key(id)) +func (s *StorageRedis) GetSession(ctx context.Context, id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { + intlog.Printf(ctx, "StorageRedis.GetSession: %s, %v", id, ttl) + r, err := s.redis.Ctx(ctx).DoVar("GET", s.key(id)) if err != nil { return nil, err } @@ -143,21 +144,21 @@ func (s *StorageRedis) GetSession(id string, ttl time.Duration, data *gmap.StrAn // SetSession updates the data map for specified session id. // This function is called ever after session, which is changed dirty, is closed. // This copy all session data map from memory to storage. -func (s *StorageRedis) SetSession(id string, data *gmap.StrAnyMap, ttl time.Duration) error { - intlog.Printf("StorageRedis.SetSession: %s, %v, %v", id, data, ttl) +func (s *StorageRedis) SetSession(ctx context.Context, id string, data *gmap.StrAnyMap, ttl time.Duration) error { + intlog.Printf(ctx, "StorageRedis.SetSession: %s, %v, %v", id, data, ttl) content, err := json.Marshal(data) if err != nil { return err } - _, err = s.redis.DoVar("SETEX", s.key(id), int64(ttl.Seconds()), content) + _, err = s.redis.Ctx(ctx).DoVar("SETEX", s.key(id), int64(ttl.Seconds()), content) return err } // UpdateTTL updates the TTL for specified session id. // This function is called ever after session, which is not dirty, is closed. // It just adds the session id to the async handling queue. -func (s *StorageRedis) UpdateTTL(id string, ttl time.Duration) error { - intlog.Printf("StorageRedis.UpdateTTL: %s, %v", id, ttl) +func (s *StorageRedis) UpdateTTL(ctx context.Context, id string, ttl time.Duration) error { + intlog.Printf(ctx, "StorageRedis.UpdateTTL: %s, %v", id, ttl) if ttl >= DefaultStorageRedisLoopInterval { s.updatingIdMap.Set(id, int(ttl.Seconds())) } @@ -165,9 +166,9 @@ func (s *StorageRedis) UpdateTTL(id string, ttl time.Duration) error { } // doUpdateTTL updates the TTL for session id. 
-func (s *StorageRedis) doUpdateTTL(id string, ttlSeconds int) error { - intlog.Printf("StorageRedis.doUpdateTTL: %s, %d", id, ttlSeconds) - _, err := s.redis.DoVar("EXPIRE", s.key(id), ttlSeconds) +func (s *StorageRedis) doUpdateTTL(ctx context.Context, id string, ttlSeconds int) error { + intlog.Printf(ctx, "StorageRedis.doUpdateTTL: %s, %d", id, ttlSeconds) + _, err := s.redis.Ctx(ctx).DoVar("EXPIRE", s.key(id), ttlSeconds) return err } diff --git a/os/gsession/gsession_storage_redis_hashtable.go b/os/gsession/gsession_storage_redis_hashtable.go index 15eb825a3..352e5f8b7 100644 --- a/os/gsession/gsession_storage_redis_hashtable.go +++ b/os/gsession/gsession_storage_redis_hashtable.go @@ -7,6 +7,7 @@ package gsession import ( + "context" "time" "github.com/gogf/gf/container/gmap" @@ -38,14 +39,14 @@ func NewStorageRedisHashTable(redis *gredis.Redis, prefix ...string) *StorageRed // New creates a session id. // This function can be used for custom session creation. -func (s *StorageRedisHashTable) New(ttl time.Duration) (id string) { +func (s *StorageRedisHashTable) New(ctx context.Context, ttl time.Duration) (id string) { return "" } // Get retrieves session value with given key. // It returns nil if the key does not exist in the session. -func (s *StorageRedisHashTable) Get(id string, key string) interface{} { - r, _ := s.redis.Do("HGET", s.key(id), key) +func (s *StorageRedisHashTable) Get(ctx context.Context, id string, key string) interface{} { + r, _ := s.redis.Ctx(ctx).Do("HGET", s.key(id), key) if r != nil { return gconv.String(r) } @@ -53,8 +54,8 @@ func (s *StorageRedisHashTable) Get(id string, key string) interface{} { } // GetMap retrieves all key-value pairs as map from storage. -func (s *StorageRedisHashTable) GetMap(id string) map[string]interface{} { - r, err := s.redis.DoVar("HGETALL", s.key(id)) +func (s *StorageRedisHashTable) GetMap(ctx context.Context, id string) map[string]interface{} { + r, err := s.redis.Ctx(ctx).DoVar("HGETALL", s.key(id)) if err != nil { return nil } @@ -71,21 +72,21 @@ func (s *StorageRedisHashTable) GetMap(id string) map[string]interface{} { } // GetSize retrieves the size of key-value pairs from storage. -func (s *StorageRedisHashTable) GetSize(id string) int { - r, _ := s.redis.DoVar("HLEN", s.key(id)) +func (s *StorageRedisHashTable) GetSize(ctx context.Context, id string) int { + r, _ := s.redis.Ctx(ctx).DoVar("HLEN", s.key(id)) return r.Int() } // Set sets key-value session pair to the storage. // The parameter <ttl> specifies the TTL for the session id (not for the key-value pair). -func (s *StorageRedisHashTable) Set(id string, key string, value interface{}, ttl time.Duration) error { - _, err := s.redis.Do("HSET", s.key(id), key, value) +func (s *StorageRedisHashTable) Set(ctx context.Context, id string, key string, value interface{}, ttl time.Duration) error { + _, err := s.redis.Ctx(ctx).Do("HSET", s.key(id), key, value) return err } // SetMap batch sets key-value session pairs with map to the storage. // The parameter <ttl> specifies the TTL for the session id(not for the key-value pair). 
-func (s *StorageRedisHashTable) SetMap(id string, data map[string]interface{}, ttl time.Duration) error { +func (s *StorageRedisHashTable) SetMap(ctx context.Context, id string, data map[string]interface{}, ttl time.Duration) error { array := make([]interface{}, len(data)*2+1) array[0] = s.key(id) @@ -95,19 +96,19 @@ func (s *StorageRedisHashTable) SetMap(id string, data map[string]interface{}, t array[index+1] = v index += 2 } - _, err := s.redis.Do("HMSET", array...) + _, err := s.redis.Ctx(ctx).Do("HMSET", array...) return err } // Remove deletes key with its value from storage. -func (s *StorageRedisHashTable) Remove(id string, key string) error { - _, err := s.redis.Do("HDEL", s.key(id), key) +func (s *StorageRedisHashTable) Remove(ctx context.Context, id string, key string) error { + _, err := s.redis.Ctx(ctx).Do("HDEL", s.key(id), key) return err } // RemoveAll deletes all key-value pairs from storage. -func (s *StorageRedisHashTable) RemoveAll(id string) error { - _, err := s.redis.Do("DEL", s.key(id)) +func (s *StorageRedisHashTable) RemoveAll(ctx context.Context, id string) error { + _, err := s.redis.Ctx(ctx).Do("DEL", s.key(id)) return err } @@ -118,9 +119,9 @@ func (s *StorageRedisHashTable) RemoveAll(id string) error { // and for some storage it might be nil if memory storage is disabled. // // This function is called ever when session starts. -func (s *StorageRedisHashTable) GetSession(id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { - intlog.Printf("StorageRedisHashTable.GetSession: %s, %v", id, ttl) - r, err := s.redis.DoVar("EXISTS", s.key(id)) +func (s *StorageRedisHashTable) GetSession(ctx context.Context, id string, ttl time.Duration, data *gmap.StrAnyMap) (*gmap.StrAnyMap, error) { + intlog.Printf(ctx, "StorageRedisHashTable.GetSession: %s, %v", id, ttl) + r, err := s.redis.Ctx(ctx).DoVar("EXISTS", s.key(id)) if err != nil { return nil, err } @@ -133,17 +134,17 @@ func (s *StorageRedisHashTable) GetSession(id string, ttl time.Duration, data *g // SetSession updates the data map for specified session id. // This function is called ever after session, which is changed dirty, is closed. // This copy all session data map from memory to storage. -func (s *StorageRedisHashTable) SetSession(id string, data *gmap.StrAnyMap, ttl time.Duration) error { - intlog.Printf("StorageRedisHashTable.SetSession: %s, %v", id, ttl) - _, err := s.redis.Do("EXPIRE", s.key(id), int64(ttl.Seconds())) +func (s *StorageRedisHashTable) SetSession(ctx context.Context, id string, data *gmap.StrAnyMap, ttl time.Duration) error { + intlog.Printf(ctx, "StorageRedisHashTable.SetSession: %s, %v", id, ttl) + _, err := s.redis.Ctx(ctx).Do("EXPIRE", s.key(id), int64(ttl.Seconds())) return err } // UpdateTTL updates the TTL for specified session id. // This function is called ever after session, which is not dirty, is closed. // It just adds the session id to the async handling queue. 
-func (s *StorageRedisHashTable) UpdateTTL(id string, ttl time.Duration) error { - intlog.Printf("StorageRedisHashTable.UpdateTTL: %s, %v", id, ttl) +func (s *StorageRedisHashTable) UpdateTTL(ctx context.Context, id string, ttl time.Duration) error { + intlog.Printf(ctx, "StorageRedisHashTable.UpdateTTL: %s, %v", id, ttl) _, err := s.redis.Do("EXPIRE", s.key(id), int64(ttl.Seconds())) return err } diff --git a/os/gsession/gsession_unit_storage_file_test.go b/os/gsession/gsession_unit_storage_file_test.go index edd872b06..3a7c7ebde 100644 --- a/os/gsession/gsession_unit_storage_file_test.go +++ b/os/gsession/gsession_unit_storage_file_test.go @@ -7,6 +7,7 @@ package gsession_test import ( + "context" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/gsession" "testing" @@ -20,7 +21,7 @@ func Test_StorageFile(t *testing.T) { manager := gsession.New(time.Second, storage) sessionId := "" gtest.C(t, func(t *gtest.T) { - s := manager.New() + s := manager.New(context.TODO()) defer s.Close() s.Set("k1", "v1") s.Set("k2", "v2") @@ -34,7 +35,7 @@ func Test_StorageFile(t *testing.T) { time.Sleep(500 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Get("k1"), "v1") t.Assert(s.Get("k2"), "v2") t.Assert(s.Get("k3"), "v3") @@ -66,7 +67,7 @@ func Test_StorageFile(t *testing.T) { time.Sleep(1000 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Size(), 0) t.Assert(s.Get("k5"), nil) t.Assert(s.Get("k6"), nil) diff --git a/os/gsession/gsession_unit_storage_memory_test.go b/os/gsession/gsession_unit_storage_memory_test.go index 054e9948d..c42413d34 100644 --- a/os/gsession/gsession_unit_storage_memory_test.go +++ b/os/gsession/gsession_unit_storage_memory_test.go @@ -7,6 +7,7 @@ package gsession_test import ( + "context" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/gsession" "testing" @@ -20,7 +21,7 @@ func Test_StorageMemory(t *testing.T) { manager := gsession.New(time.Second, storage) sessionId := "" gtest.C(t, func(t *gtest.T) { - s := manager.New() + s := manager.New(context.TODO()) defer s.Close() s.Set("k1", "v1") s.Set("k2", "v2") @@ -34,7 +35,7 @@ func Test_StorageMemory(t *testing.T) { time.Sleep(500 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Get("k1"), "v1") t.Assert(s.Get("k2"), "v2") t.Assert(s.Get("k3"), "v3") @@ -66,7 +67,7 @@ func Test_StorageMemory(t *testing.T) { time.Sleep(1000 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Size(), 0) t.Assert(s.Get("k5"), nil) t.Assert(s.Get("k6"), nil) diff --git a/os/gsession/gsession_unit_storage_redis_hashtable_test.go b/os/gsession/gsession_unit_storage_redis_hashtable_test.go index 6ce58ef38..20d7ff082 100644 --- a/os/gsession/gsession_unit_storage_redis_hashtable_test.go +++ b/os/gsession/gsession_unit_storage_redis_hashtable_test.go @@ -7,6 +7,7 @@ package gsession_test import ( + "context" "github.com/gogf/gf/database/gredis" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/gsession" @@ -26,7 +27,7 @@ func Test_StorageRedisHashTable(t *testing.T) { manager := gsession.New(time.Second, storage) sessionId := "" gtest.C(t, func(t *gtest.T) { - s := manager.New() + s := manager.New(context.TODO()) defer s.Close() s.Set("k1", "v1") s.Set("k2", "v2") @@ -38,7 +39,7 @@ func 
Test_StorageRedisHashTable(t *testing.T) { sessionId = s.Id() }) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Get("k1"), "v1") t.Assert(s.Get("k2"), "v2") t.Assert(s.Get("k3"), "v3") @@ -71,7 +72,7 @@ func Test_StorageRedisHashTable(t *testing.T) { time.Sleep(1500 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Size(), 0) t.Assert(s.Get("k5"), nil) t.Assert(s.Get("k6"), nil) @@ -89,7 +90,7 @@ func Test_StorageRedisHashTablePrefix(t *testing.T) { manager := gsession.New(time.Second, storage) sessionId := "" gtest.C(t, func(t *gtest.T) { - s := manager.New() + s := manager.New(context.TODO()) defer s.Close() s.Set("k1", "v1") s.Set("k2", "v2") @@ -101,7 +102,7 @@ func Test_StorageRedisHashTablePrefix(t *testing.T) { sessionId = s.Id() }) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Get("k1"), "v1") t.Assert(s.Get("k2"), "v2") t.Assert(s.Get("k3"), "v3") @@ -134,7 +135,7 @@ func Test_StorageRedisHashTablePrefix(t *testing.T) { time.Sleep(1500 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Size(), 0) t.Assert(s.Get("k5"), nil) t.Assert(s.Get("k6"), nil) diff --git a/os/gsession/gsession_unit_storage_redis_test.go b/os/gsession/gsession_unit_storage_redis_test.go index f0fd9f3b2..53a5e754b 100644 --- a/os/gsession/gsession_unit_storage_redis_test.go +++ b/os/gsession/gsession_unit_storage_redis_test.go @@ -7,6 +7,7 @@ package gsession_test import ( + "context" "github.com/gogf/gf/database/gredis" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/gsession" @@ -24,7 +25,7 @@ func Test_StorageRedis(t *testing.T) { manager := gsession.New(time.Second, storage) sessionId := "" gtest.C(t, func(t *gtest.T) { - s := manager.New() + s := manager.New(context.TODO()) defer s.Close() s.Set("k1", "v1") s.Set("k2", "v2") @@ -38,7 +39,7 @@ func Test_StorageRedis(t *testing.T) { time.Sleep(500 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Get("k1"), "v1") t.Assert(s.Get("k2"), "v2") t.Assert(s.Get("k3"), "v3") @@ -70,7 +71,7 @@ func Test_StorageRedis(t *testing.T) { time.Sleep(1000 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Size(), 0) t.Assert(s.Get("k5"), nil) t.Assert(s.Get("k6"), nil) @@ -86,7 +87,7 @@ func Test_StorageRedisPrefix(t *testing.T) { manager := gsession.New(time.Second, storage) sessionId := "" gtest.C(t, func(t *gtest.T) { - s := manager.New() + s := manager.New(context.TODO()) defer s.Close() s.Set("k1", "v1") s.Set("k2", "v2") @@ -100,7 +101,7 @@ func Test_StorageRedisPrefix(t *testing.T) { time.Sleep(500 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Get("k1"), "v1") t.Assert(s.Get("k2"), "v2") t.Assert(s.Get("k3"), "v3") @@ -132,7 +133,7 @@ func Test_StorageRedisPrefix(t *testing.T) { time.Sleep(1000 * time.Millisecond) gtest.C(t, func(t *gtest.T) { - s := manager.New(sessionId) + s := manager.New(context.TODO(), sessionId) t.Assert(s.Size(), 0) t.Assert(s.Get("k5"), nil) t.Assert(s.Get("k6"), nil) diff --git a/os/gspath/gspath.go b/os/gspath/gspath.go index d17ce2658..e54300629 100644 --- 
a/os/gspath/gspath.go +++ b/os/gspath/gspath.go @@ -12,8 +12,8 @@ package gspath import ( - "errors" - "fmt" + "context" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "os" "sort" @@ -103,7 +103,7 @@ func (sp *SPath) Set(path string) (realPath string, err error) { } } if realPath == "" { - return realPath, errors.New(fmt.Sprintf(`path "%s" does not exist`, path)) + return realPath, gerror.Newf(`path "%s" does not exist`, path) } // The set path must be a directory. if gfile.IsDir(realPath) { @@ -113,7 +113,7 @@ func (sp *SPath) Set(path string) (realPath string, err error) { sp.removeMonitorByPath(v) } } - intlog.Print("paths clear:", sp.paths) + intlog.Print(context.TODO(), "paths clear:", sp.paths) sp.paths.Clear() if sp.cache != nil { sp.cache.Clear() @@ -123,7 +123,7 @@ func (sp *SPath) Set(path string) (realPath string, err error) { sp.addMonitorByPath(realPath) return realPath, nil } else { - return "", errors.New(path + " should be a folder") + return "", gerror.New(path + " should be a folder") } } @@ -138,7 +138,7 @@ func (sp *SPath) Add(path string) (realPath string, err error) { } } if realPath == "" { - return realPath, errors.New(fmt.Sprintf(`path "%s" does not exist`, path)) + return realPath, gerror.Newf(`path "%s" does not exist`, path) } // The added path must be a directory. if gfile.IsDir(realPath) { @@ -152,7 +152,7 @@ func (sp *SPath) Add(path string) (realPath string, err error) { } return realPath, nil } else { - return "", errors.New(path + " should be a folder") + return "", gerror.New(path + " should be a folder") } } diff --git a/os/gtimer/gtimer.go b/os/gtimer/gtimer.go index 3ffd36326..5cc849f7b 100644 --- a/os/gtimer/gtimer.go +++ b/os/gtimer/gtimer.go @@ -19,7 +19,6 @@ package gtimer import ( - "fmt" "github.com/gogf/gf/container/gtype" "math" "sync" @@ -43,21 +42,19 @@ type TimerOptions struct { } const ( - StatusReady = 0 // Job or Timer is ready for running. - StatusRunning = 1 // Job or Timer is already running. - StatusStopped = 2 // Job or Timer is stopped. - StatusClosed = -1 // Job or Timer is closed and waiting to be deleted. - panicExit = "exit" // panicExit is used for custom job exit with panic. - defaultTimes = math.MaxInt32 // defaultTimes is the default limit running times, a big number. - defaultTimerInterval = 100 // defaultTimerInterval is the default timer interval in milliseconds. - cmdEnvKey = "gf.gtimer" // Configuration key for command argument or environment. + StatusReady = 0 // Job or Timer is ready for running. + StatusRunning = 1 // Job or Timer is already running. + StatusStopped = 2 // Job or Timer is stopped. + StatusClosed = -1 // Job or Timer is closed and waiting to be deleted. + panicExit = "exit" // panicExit is used for custom job exit with panic. + defaultTimes = math.MaxInt32 // defaultTimes is the default limit running times, a big number. + defaultTimerInterval = 100 // defaultTimerInterval is the default timer interval in milliseconds. + commandEnvKeyForInterval = "gf.gtimer.interval" // commandEnvKeyForInterval is the key for command argument or environment configuring default interval duration for timer. ) var ( defaultTimer = New() - defaultInterval = gcmd.GetOptWithEnv( - fmt.Sprintf("%s.interval", cmdEnvKey), defaultTimerInterval, - ).Duration() * time.Millisecond + defaultInterval = gcmd.GetOptWithEnv(commandEnvKeyForInterval, defaultTimerInterval).Duration() * time.Millisecond ) // DefaultOptions creates and returns a default options object for Timer creation. 
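The gsession hunks above thread context.Context through Manager.New, the Session itself and every Storage method, so third-party Storage implementations need the same first parameter added. A hedged usage sketch that mirrors the updated unit tests (memory storage, context.TODO()):

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/gogf/gf/os/gsession"
)

func main() {
	ctx := context.TODO()
	manager := gsession.New(time.Hour, gsession.NewStorageMemory())

	// The context passed here is stored on the Session and threaded through
	// every subsequent Storage call (Get/Set/SetSession/UpdateTTL/...).
	s := manager.New(ctx)
	if err := s.Set("user-id", 100); err != nil {
		panic(err)
	}
	id := s.Id()
	s.Close()

	// Re-open the same session by id, again with an explicit context.
	s2 := manager.New(ctx, id)
	fmt.Println(s2.Get("user-id"))
}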
diff --git a/os/gview/gview.go b/os/gview/gview.go index dc0013add..1b072f6e3 100644 --- a/os/gview/gview.go +++ b/os/gview/gview.go @@ -36,6 +36,10 @@ type ( FuncMap = map[string]interface{} // FuncMap is type for custom template functions. ) +const ( + commandEnvKeyForPath = "gf.gview.path" +) + var ( // Default view object. defaultViewObj *View @@ -68,14 +72,14 @@ func New(path ...string) *View { } if len(path) > 0 && len(path[0]) > 0 { if err := view.SetPath(path[0]); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } else { // Customized dir path from env/cmd. - if envPath := gcmd.GetOptWithEnv("gf.gview.path").String(); envPath != "" { + if envPath := gcmd.GetOptWithEnv(commandEnvKeyForPath).String(); envPath != "" { if gfile.Exists(envPath) { if err := view.SetPath(envPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } else { if errorPrint() { @@ -85,18 +89,18 @@ func New(path ...string) *View { } else { // Dir path of working dir. if err := view.SetPath(gfile.Pwd()); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } // Dir path of binary. if selfPath := gfile.SelfDir(); selfPath != "" && gfile.Exists(selfPath) { if err := view.AddPath(selfPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } // Dir path of main package. if mainPath := gfile.MainPkgPath(); mainPath != "" && gfile.Exists(mainPath) { if err := view.AddPath(mainPath); err != nil { - intlog.Error(err) + intlog.Error(context.TODO(), err) } } } diff --git a/os/gview/gview_config.go b/os/gview/gview_config.go index 508344a0a..e596b07bf 100644 --- a/os/gview/gview_config.go +++ b/os/gview/gview_config.go @@ -7,8 +7,8 @@ package gview import ( - "errors" - "fmt" + "context" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/i18n/gi18n" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" @@ -67,14 +67,14 @@ func (view *View) SetConfig(config Config) error { // It's just cache, do not hesitate clearing it. templates.Clear() - intlog.Printf("SetConfig: %+v", view.config) + intlog.Printf(context.TODO(), "SetConfig: %+v", view.config) return nil } // SetConfigWithMap set configurations with map for the view. func (view *View) SetConfigWithMap(m map[string]interface{}) error { if m == nil || len(m) == 0 { - return errors.New("configuration cannot be empty") + return gerror.New("configuration cannot be empty") } // The m now is a shallow copy of m. // Any changes to m does not affect the original one. @@ -123,7 +123,7 @@ func (view *View) SetPath(path string) error { } // Path not exist. if realPath == "" { - err := errors.New(fmt.Sprintf(`[gview] SetPath failed: path "%s" does not exist`, path)) + err := gerror.Newf(`[gview] SetPath failed: path "%s" does not exist`, path) if errorPrint() { glog.Error(err) } @@ -131,7 +131,7 @@ func (view *View) SetPath(path string) error { } // Should be a directory. if !isDir { - err := errors.New(fmt.Sprintf(`[gview] SetPath failed: path "%s" should be directory type`, path)) + err := gerror.Newf(`[gview] SetPath failed: path "%s" should be directory type`, path) if errorPrint() { glog.Error(err) } @@ -177,7 +177,7 @@ func (view *View) AddPath(path string) error { } // Path not exist. 
if realPath == "" { - err := errors.New(fmt.Sprintf(`[gview] AddPath failed: path "%s" does not exist`, path)) + err := gerror.Newf(`[gview] AddPath failed: path "%s" does not exist`, path) if errorPrint() { glog.Error(err) } @@ -185,7 +185,7 @@ func (view *View) AddPath(path string) error { } // realPath should be type of folder. if !isDir { - err := errors.New(fmt.Sprintf(`[gview] AddPath failed: path "%s" should be directory type`, path)) + err := gerror.Newf(`[gview] AddPath failed: path "%s" should be directory type`, path) if errorPrint() { glog.Error(err) } diff --git a/os/gview/gview_error.go b/os/gview/gview_error.go index 4c38fafe5..268321bf7 100644 --- a/os/gview/gview_error.go +++ b/os/gview/gview_error.go @@ -11,12 +11,12 @@ import ( ) const ( - // gERROR_PRINT_KEY is used to specify the key controlling error printing to stdout. + // commandEnvKeyForErrorPrint is used to specify the key controlling error printing to stdout. // This error is designed not to be returned by functions. - gERROR_PRINT_KEY = "gf.gview.errorprint" + commandEnvKeyForErrorPrint = "gf.gview.errorprint" ) // errorPrint checks whether printing error to stdout. func errorPrint() bool { - return gcmd.GetOptWithEnv(gERROR_PRINT_KEY, true).Bool() + return gcmd.GetOptWithEnv(commandEnvKeyForErrorPrint, true).Bool() } diff --git a/os/gview/gview_parse.go b/os/gview/gview_parse.go index f69d7e08b..ef2e242b9 100644 --- a/os/gview/gview_parse.go +++ b/os/gview/gview_parse.go @@ -9,7 +9,6 @@ package gview import ( "bytes" "context" - "errors" "fmt" "github.com/gogf/gf/encoding/ghash" "github.com/gogf/gf/errors/gerror" @@ -83,7 +82,7 @@ func (view *View) Parse(ctx context.Context, file string, params ...Params) (res templates.Clear() gfsnotify.Exit() }); err != nil { - intlog.Error(err) + intlog.Error(ctx, err) } } return &fileCacheItem{ @@ -377,7 +376,7 @@ func (view *View) searchFile(file string) (path string, folder string, resource if errorPrint() { glog.Error(buffer.String()) } - err = errors.New(fmt.Sprintf(`template file "%s" not found`, file)) + err = gerror.Newf(`template file "%s" not found`, file) } return } diff --git a/util/gconv/gconv.go b/util/gconv/gconv.go index be6b027ac..77e492fe8 100644 --- a/util/gconv/gconv.go +++ b/util/gconv/gconv.go @@ -46,238 +46,257 @@ var ( StructTagPriority = []string{"gconv", "param", "params", "c", "p", "json"} ) -// Convert converts the variable `any` to the type `t`, the type `t` is specified by string. -// The optional parameter `params` is used for additional necessary parameter for this conversion. -// It supports common types conversion as its conversion based on type name string. -func Convert(any interface{}, t string, params ...interface{}) interface{} { - switch t { +type doConvertInput struct { + FromValue interface{} // Value that is converted from. + ToTypeName string // Target value type name in string. + ReferValue interface{} // Referred value, a value in type `ToTypeName`. + Extra []interface{} // Extra values for implementing the converting. +} + +// doConvert does common used types converting. 
+func doConvert(input doConvertInput) interface{} { + switch input.ToTypeName { case "int": - return Int(any) + return Int(input.FromValue) case "*int": - if _, ok := any.(*int); ok { - return any + if _, ok := input.FromValue.(*int); ok { + return input.FromValue } - v := Int(any) + v := Int(input.FromValue) return &v case "int8": - return Int8(any) + return Int8(input.FromValue) case "*int8": - if _, ok := any.(*int8); ok { - return any + if _, ok := input.FromValue.(*int8); ok { + return input.FromValue } - v := Int8(any) + v := Int8(input.FromValue) return &v case "int16": - return Int16(any) + return Int16(input.FromValue) case "*int16": - if _, ok := any.(*int16); ok { - return any + if _, ok := input.FromValue.(*int16); ok { + return input.FromValue } - v := Int16(any) + v := Int16(input.FromValue) return &v case "int32": - return Int32(any) + return Int32(input.FromValue) case "*int32": - if _, ok := any.(*int32); ok { - return any + if _, ok := input.FromValue.(*int32); ok { + return input.FromValue } - v := Int32(any) + v := Int32(input.FromValue) return &v case "int64": - return Int64(any) + return Int64(input.FromValue) case "*int64": - if _, ok := any.(*int64); ok { - return any + if _, ok := input.FromValue.(*int64); ok { + return input.FromValue } - v := Int64(any) + v := Int64(input.FromValue) return &v case "uint": - return Uint(any) + return Uint(input.FromValue) case "*uint": - if _, ok := any.(*uint); ok { - return any + if _, ok := input.FromValue.(*uint); ok { + return input.FromValue } - v := Uint(any) + v := Uint(input.FromValue) return &v case "uint8": - return Uint8(any) + return Uint8(input.FromValue) case "*uint8": - if _, ok := any.(*uint8); ok { - return any + if _, ok := input.FromValue.(*uint8); ok { + return input.FromValue } - v := Uint8(any) + v := Uint8(input.FromValue) return &v case "uint16": - return Uint16(any) + return Uint16(input.FromValue) case "*uint16": - if _, ok := any.(*uint16); ok { - return any + if _, ok := input.FromValue.(*uint16); ok { + return input.FromValue } - v := Uint16(any) + v := Uint16(input.FromValue) return &v case "uint32": - return Uint32(any) + return Uint32(input.FromValue) case "*uint32": - if _, ok := any.(*uint32); ok { - return any + if _, ok := input.FromValue.(*uint32); ok { + return input.FromValue } - v := Uint32(any) + v := Uint32(input.FromValue) return &v case "uint64": - return Uint64(any) + return Uint64(input.FromValue) case "*uint64": - if _, ok := any.(*uint64); ok { - return any + if _, ok := input.FromValue.(*uint64); ok { + return input.FromValue } - v := Uint64(any) + v := Uint64(input.FromValue) return &v case "float32": - return Float32(any) + return Float32(input.FromValue) case "*float32": - if _, ok := any.(*float32); ok { - return any + if _, ok := input.FromValue.(*float32); ok { + return input.FromValue } - v := Float32(any) + v := Float32(input.FromValue) return &v case "float64": - return Float64(any) + return Float64(input.FromValue) case "*float64": - if _, ok := any.(*float64); ok { - return any + if _, ok := input.FromValue.(*float64); ok { + return input.FromValue } - v := Float64(any) + v := Float64(input.FromValue) return &v case "bool": - return Bool(any) + return Bool(input.FromValue) case "*bool": - if _, ok := any.(*bool); ok { - return any + if _, ok := input.FromValue.(*bool); ok { + return input.FromValue } - v := Bool(any) + v := Bool(input.FromValue) return &v case "string": - return String(any) + return String(input.FromValue) case "*string": - if _, ok := any.(*string); ok 
{ - return any + if _, ok := input.FromValue.(*string); ok { + return input.FromValue } - v := String(any) + v := String(input.FromValue) return &v case "[]byte": - return Bytes(any) + return Bytes(input.FromValue) case "[]int": - return Ints(any) + return Ints(input.FromValue) case "[]int32": - return Int32s(any) + return Int32s(input.FromValue) case "[]int64": - return Int64s(any) + return Int64s(input.FromValue) case "[]uint": - return Uints(any) + return Uints(input.FromValue) + case "[]uint8": + return Bytes(input.FromValue) case "[]uint32": - return Uint32s(any) + return Uint32s(input.FromValue) case "[]uint64": - return Uint64s(any) + return Uint64s(input.FromValue) case "[]float32": - return Float32s(any) + return Float32s(input.FromValue) case "[]float64": - return Float64s(any) + return Float64s(input.FromValue) case "[]string": - return Strings(any) + return Strings(input.FromValue) case "Time", "time.Time": - if len(params) > 0 { - return Time(any, String(params[0])) + if len(input.Extra) > 0 { + return Time(input.FromValue, String(input.Extra[0])) } - return Time(any) + return Time(input.FromValue) case "*time.Time": var v interface{} - if len(params) > 0 { - v = Time(any, String(params[0])) + if len(input.Extra) > 0 { + v = Time(input.FromValue, String(input.Extra[0])) } else { - if _, ok := any.(*time.Time); ok { - return any + if _, ok := input.FromValue.(*time.Time); ok { + return input.FromValue } - v = Time(any) + v = Time(input.FromValue) } return &v case "GTime", "gtime.Time": - if len(params) > 0 { - if v := GTime(any, String(params[0])); v != nil { + if len(input.Extra) > 0 { + if v := GTime(input.FromValue, String(input.Extra[0])); v != nil { return *v } else { return *gtime.New() } } - if v := GTime(any); v != nil { + if v := GTime(input.FromValue); v != nil { return *v } else { return *gtime.New() } case "*gtime.Time": - if len(params) > 0 { - if v := GTime(any, String(params[0])); v != nil { + if len(input.Extra) > 0 { + if v := GTime(input.FromValue, String(input.Extra[0])); v != nil { return v } else { return gtime.New() } } - if v := GTime(any); v != nil { + if v := GTime(input.FromValue); v != nil { return v } else { return gtime.New() } case "Duration", "time.Duration": - return Duration(any) + return Duration(input.FromValue) case "*time.Duration": - if _, ok := any.(*time.Duration); ok { - return any + if _, ok := input.FromValue.(*time.Duration); ok { + return input.FromValue } - v := Duration(any) + v := Duration(input.FromValue) return &v case "map[string]string": - return MapStrStr(any) + return MapStrStr(input.FromValue) case "map[string]interface{}": - return Map(any) + return Map(input.FromValue) case "[]map[string]interface{}": - return Maps(any) - - //case "gvar.Var": - // // TODO remove reflect usage to create gvar.Var, considering using unsafe pointer - // rv := reflect.New(intstore.ReflectTypeVarImp) - // ri := rv.Interface() - // if v, ok := ri.(apiSet); ok { - // v.Set(any) - // } else if v, ok := ri.(apiUnmarshalValue); ok { - // v.UnmarshalValue(any) - // } else { - // rv.Set(reflect.ValueOf(any)) - // } - // return ri + return Maps(input.FromValue) default: - return any + if input.ReferValue != nil { + var ( + referReflectValue reflect.Value + ) + if v, ok := input.ReferValue.(reflect.Value); ok { + referReflectValue = v + } else { + referReflectValue = reflect.ValueOf(input.ReferValue) + } + input.ToTypeName = referReflectValue.Kind().String() + input.ReferValue = nil + return 
reflect.ValueOf(doConvert(input)).Convert(referReflectValue.Type()).Interface() + } + return input.FromValue } } +// Convert converts the variable `fromValue` to the type `toTypeName`, the type `toTypeName` is specified by string. +// The optional parameter `extraParams` is used for additional necessary parameter for this conversion. +// It supports common types conversion as its conversion based on type name string. +func Convert(fromValue interface{}, toTypeName string, extraParams ...interface{}) interface{} { + return doConvert(doConvertInput{ + FromValue: fromValue, + ToTypeName: toTypeName, + ReferValue: nil, + Extra: extraParams, + }) +} + // Byte converts `any` to byte. func Byte(any interface{}) byte { if v, ok := any.(byte); ok { diff --git a/util/gconv/gconv_map.go b/util/gconv/gconv_map.go index 6685433b2..f2878ce56 100644 --- a/util/gconv/gconv_map.go +++ b/util/gconv/gconv_map.go @@ -166,7 +166,7 @@ func doMapConvert(value interface{}, recursive bool, tags ...string) map[string] dataMap[String(reflectValue.Index(i).Interface())] = nil } } - case reflect.Map, reflect.Struct: + case reflect.Map, reflect.Struct, reflect.Interface: convertedValue := doMapConvertForMapOrStructValue(true, value, recursive, newTags...) if m, ok := convertedValue.(map[string]interface{}); ok { return m diff --git a/util/gconv/gconv_scan.go b/util/gconv/gconv_scan.go index 4b481057b..1dd76832a 100644 --- a/util/gconv/gconv_scan.go +++ b/util/gconv/gconv_scan.go @@ -32,6 +32,7 @@ func Scan(params interface{}, pointer interface{}, mapping ...map[string]string) switch pointerElemKind { case reflect.Map: return MapToMap(params, pointer, mapping...) + case reflect.Array, reflect.Slice: var ( sliceElem = pointerElem.Elem() diff --git a/util/gconv/gconv_struct.go b/util/gconv/gconv_struct.go index 06416f863..354905ad2 100644 --- a/util/gconv/gconv_struct.go +++ b/util/gconv/gconv_struct.go @@ -84,6 +84,8 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string if rv, ok := pointer.(reflect.Value); ok { if rv.Kind() == reflect.Ptr { return json.UnmarshalUseNumber(r, rv.Interface()) + } else if rv.CanAddr() { + return json.UnmarshalUseNumber(r, rv.Addr().Interface()) } } else { return json.UnmarshalUseNumber(r, pointer) @@ -94,6 +96,8 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string if rv, ok := pointer.(reflect.Value); ok { if rv.Kind() == reflect.Ptr { return json.UnmarshalUseNumber(paramsBytes, rv.Interface()) + } else if rv.CanAddr() { + return json.UnmarshalUseNumber(paramsBytes, rv.Addr().Interface()) } } else { return json.UnmarshalUseNumber(paramsBytes, pointer) @@ -103,6 +107,7 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string var ( paramsReflectValue reflect.Value + paramsInterface interface{} // DO NOT use `params` directly as it might be type of `reflect.Value` pointerReflectValue reflect.Value pointerReflectKind reflect.Kind pointerElemReflectValue reflect.Value // The pointed element. @@ -112,6 +117,7 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string } else { paramsReflectValue = reflect.ValueOf(params) } + paramsInterface = paramsReflectValue.Interface() if v, ok := pointer.(reflect.Value); ok { pointerReflectValue = v pointerElemReflectValue = v @@ -135,7 +141,7 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string } // Normal unmarshalling interfaces checks. 
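The exported gconv.Convert keeps its signature (only the parameter names change); the body moves into doConvert, whose doConvertInput carries an optional ReferValue used by the struct binding in gconv_struct.go below, plus a new "[]uint8" alias for "[]byte". A short sketch of the unchanged public behaviour:

package main

import (
	"fmt"

	"github.com/gogf/gf/util/gconv"
)

func main() {
	// Target types are still named by string; extra parameters (here a
	// gtime layout) are passed through to the underlying converter.
	fmt.Println(gconv.Convert("123", "int"))                       // 123
	fmt.Println(gconv.Convert("abc", "[]byte"))                    // [97 98 99]
	fmt.Println(gconv.Convert("2021-06-01", "time.Time", "Y-m-d")) // parsed with the given layout

	// "[]uint8" is newly accepted as an alias of "[]byte".
	fmt.Println(gconv.Convert("abc", "[]uint8")) // [97 98 99]
}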
- if err, ok := bindVarToReflectValueWithInterfaceCheck(pointerReflectValue, params); ok { + if err, ok := bindVarToReflectValueWithInterfaceCheck(pointerReflectValue, paramsInterface); ok { return err } @@ -150,7 +156,7 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string // return v.UnmarshalValue(params) //} // Note that it's `pointerElemReflectValue` here not `pointerReflectValue`. - if err, ok := bindVarToReflectValueWithInterfaceCheck(pointerElemReflectValue, params); ok { + if err, ok := bindVarToReflectValueWithInterfaceCheck(pointerElemReflectValue, paramsInterface); ok { return err } // Retrieve its element, may be struct at last. @@ -159,7 +165,7 @@ func doStruct(params interface{}, pointer interface{}, mapping map[string]string // paramsMap is the map[string]interface{} type variable for params. // DO NOT use MapDeep here. - paramsMap := Map(params) + paramsMap := Map(paramsInterface) if paramsMap == nil { return gerror.Newf("convert params to map failed: %v", params) } @@ -300,7 +306,7 @@ func bindVarToStructAttr(elem reflect.Value, name string, value interface{}, map return nil } defer func() { - if e := recover(); e != nil { + if exception := recover(); exception != nil { if err = bindVarToReflectValue(structFieldValue, value, mapping, priorityTag); err != nil { err = gerror.Wrapf(err, `error binding value to attribute "%s"`, name) } @@ -310,7 +316,13 @@ func bindVarToStructAttr(elem reflect.Value, name string, value interface{}, map if empty.IsNil(value) { structFieldValue.Set(reflect.Zero(structFieldValue.Type())) } else { - structFieldValue.Set(reflect.ValueOf(Convert(value, structFieldValue.Type().String()))) + structFieldValue.Set(reflect.ValueOf(doConvert( + doConvertInput{ + FromValue: value, + ToTypeName: structFieldValue.Type().String(), + ReferValue: structFieldValue, + }, + ))) } return nil } @@ -331,9 +343,11 @@ func bindVarToReflectValueWithInterfaceCheck(reflectValue reflect.Value, value i } pointer = reflectValue.Interface() } + // UnmarshalValue. if v, ok := pointer.(apiUnmarshalValue); ok { return v.UnmarshalValue(value), ok } + // UnmarshalText. if v, ok := pointer.(apiUnmarshalText); ok { if s, ok := value.(string); ok { return v.UnmarshalText([]byte(s)), ok @@ -446,11 +460,12 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}, ma default: defer func() { - if e := recover(); e != nil { + if exception := recover(); exception != nil { err = gerror.New( - fmt.Sprintf(`cannot convert value "%+v" to type "%s"`, + fmt.Sprintf(`cannot convert value "%+v" to type "%s":%+v`, value, structFieldValue.Type().String(), + exception, ), ) } diff --git a/util/gconv/gconv_structs.go b/util/gconv/gconv_structs.go index 8d3626d36..d8aafef79 100644 --- a/util/gconv/gconv_structs.go +++ b/util/gconv/gconv_structs.go @@ -100,13 +100,34 @@ func doStructs(params interface{}, pointer interface{}, mapping map[string]strin } } // Converting `params` to map slice. 
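Handing the target field to doConvert as ReferValue gives named types a direct conversion path: when the field's concrete type name matches no case in the switch, the value is converted by the field's kind and then reflect-converted to the field type, instead of falling into the recover-based fallback. A hedged example of the kind of binding this covers; the Level type and field names are illustrative, not taken from the library:

package main

import (
	"fmt"

	"github.com/gogf/gf/util/gconv"
)

// Level is a defined type; its name matches no explicit case in doConvert,
// so binding goes through the ReferValue/kind path.
type Level int

type User struct {
	Name  string
	Level Level
}

func main() {
	var user User
	// "3" is converted by the field's kind (int) and then to the concrete
	// Level type via reflection.
	if err := gconv.Struct(map[string]interface{}{"name": "john", "level": "3"}, &user); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", user) // {Name:john Level:3}
}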
- paramsMaps := Maps(params) + var ( + paramsList []interface{} + paramsRv = reflect.ValueOf(params) + paramsKind = paramsRv.Kind() + ) + for paramsKind == reflect.Ptr { + paramsRv = paramsRv.Elem() + paramsKind = paramsRv.Kind() + } + switch paramsKind { + case reflect.Slice, reflect.Array: + paramsList = make([]interface{}, paramsRv.Len()) + for i := 0; i < paramsRv.Len(); i++ { + paramsList[i] = paramsRv.Index(i) + } + default: + var paramsMaps = Maps(params) + paramsList = make([]interface{}, len(paramsMaps)) + for i := 0; i < len(paramsMaps); i++ { + paramsList[i] = paramsMaps[i] + } + } // If `params` is an empty slice, no conversion. - if len(paramsMaps) == 0 { + if len(paramsList) == 0 { return nil } var ( - reflectElemArray = reflect.MakeSlice(pointerRv.Type().Elem(), len(paramsMaps), len(paramsMaps)) + reflectElemArray = reflect.MakeSlice(pointerRv.Type().Elem(), len(paramsList), len(paramsList)) itemType = reflectElemArray.Index(0).Type() itemTypeKind = itemType.Kind() pointerRvElem = pointerRv.Elem() @@ -114,7 +135,7 @@ func doStructs(params interface{}, pointer interface{}, mapping map[string]strin ) if itemTypeKind == reflect.Ptr { // Pointer element. - for i := 0; i < len(paramsMaps); i++ { + for i := 0; i < len(paramsList); i++ { var tempReflectValue reflect.Value if i < pointerRvLength { // Might be nil. @@ -123,21 +144,21 @@ func doStructs(params interface{}, pointer interface{}, mapping map[string]strin if !tempReflectValue.IsValid() { tempReflectValue = reflect.New(itemType.Elem()).Elem() } - if err = doStruct(paramsMaps[i], tempReflectValue, mapping, priorityTag); err != nil { + if err = doStruct(paramsList[i], tempReflectValue, mapping, priorityTag); err != nil { return err } reflectElemArray.Index(i).Set(tempReflectValue.Addr()) } } else { // Struct element. 
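In doStructs, building paramsList from the slice elements themselves (as reflect.Values) instead of pre-flattening everything with Maps(params) keeps each element's original value and concrete type; only non-slice inputs still go through the map conversion. A sketch of a struct-slice source, with illustrative types, that now reaches doStruct element by element:

package main

import (
	"fmt"

	"github.com/gogf/gf/util/gconv"
)

type Score struct {
	Subject string
	Result  int
}

type Student struct {
	Name   string
	Scores []Score
}

func main() {
	// The source is a slice of structs rather than a slice of maps; each
	// element is handed to doStruct as-is.
	src := []Student{
		{Name: "john", Scores: []Score{{Subject: "math", Result: 100}}},
		{Name: "smith", Scores: []Score{{Subject: "english", Result: 60}}},
	}
	var dst []*Student
	if err := gconv.Structs(src, &dst); err != nil {
		panic(err)
	}
	fmt.Println(dst[0].Name, dst[1].Scores[0].Result) // john 60
}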
- for i := 0; i < len(paramsMaps); i++ { + for i := 0; i < len(paramsList); i++ { var tempReflectValue reflect.Value if i < pointerRvLength { tempReflectValue = pointerRvElem.Index(i) } else { tempReflectValue = reflect.New(itemType).Elem() } - if err = doStruct(paramsMaps[i], tempReflectValue, mapping, priorityTag); err != nil { + if err = doStruct(paramsList[i], tempReflectValue, mapping, priorityTag); err != nil { return err } reflectElemArray.Index(i).Set(tempReflectValue) diff --git a/util/gconv/gconv_z_unit_scan_test.go b/util/gconv/gconv_z_unit_scan_test.go index d80448818..541baf7dc 100644 --- a/util/gconv/gconv_z_unit_scan_test.go +++ b/util/gconv/gconv_z_unit_scan_test.go @@ -58,7 +58,7 @@ func Test_Scan_StructStructs(t *testing.T) { } ) err := gconv.Scan(params, &users) - t.Assert(err, nil) + t.AssertNil(err) t.Assert(users, g.Slice{ &User{ Uid: 1, diff --git a/util/gconv/gconv_z_unit_struct_marshal_unmarshal_test.go b/util/gconv/gconv_z_unit_struct_marshal_unmarshal_test.go index 9fc082944..b7c0941f5 100644 --- a/util/gconv/gconv_z_unit_struct_marshal_unmarshal_test.go +++ b/util/gconv/gconv_z_unit_struct_marshal_unmarshal_test.go @@ -7,9 +7,9 @@ package gconv_test import ( - "errors" "github.com/gogf/gf/crypto/gcrc32" "github.com/gogf/gf/encoding/gbinary" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/test/gtest" @@ -83,16 +83,16 @@ func (p *Pkg) Marshal() []byte { func (p *Pkg) UnmarshalValue(v interface{}) error { b := gconv.Bytes(v) if len(b) < 6 { - return errors.New("invalid package length") + return gerror.New("invalid package length") } p.Length = gbinary.DecodeToUint16(b[:2]) if len(b) < int(p.Length) { - return errors.New("invalid data length") + return gerror.New("invalid data length") } p.Crc32 = gbinary.DecodeToUint32(b[2:6]) p.Data = b[6:] if gcrc32.Encrypt(p.Data) != p.Crc32 { - return errors.New("crc32 validation failed") + return gerror.New("crc32 validation failed") } return nil } diff --git a/util/gmode/gmode.go b/util/gmode/gmode.go index 9b43dacaf..199e89a0d 100644 --- a/util/gmode/gmode.go +++ b/util/gmode/gmode.go @@ -16,15 +16,16 @@ import ( ) const ( - NOT_SET = "not-set" - DEVELOP = "develop" - TESTING = "testing" - STAGING = "staging" - PRODUCT = "product" - cmdEnvKey = "gf.gmode" + NOT_SET = "not-set" + DEVELOP = "develop" + TESTING = "testing" + STAGING = "staging" + PRODUCT = "product" + commandEnvKey = "gf.gmode" ) var ( + // Note that `currentMode` is not concurrent safe. currentMode = NOT_SET ) @@ -57,7 +58,7 @@ func SetProduct() { func Mode() string { // If current mode is not set, do this auto check. if currentMode == NOT_SET { - if v := gcmd.GetOptWithEnv(cmdEnvKey).String(); v != "" { + if v := gcmd.GetOptWithEnv(commandEnvKey).String(); v != "" { // Mode configured from command argument of environment. currentMode = v } else { diff --git a/util/gpage/gpage.go b/util/gpage/gpage.go index ef9b9e2d3..a0b05ef3c 100644 --- a/util/gpage/gpage.go +++ b/util/gpage/gpage.go @@ -35,8 +35,8 @@ type Page struct { } const ( - PAGE_NAME = "page" // PAGE_NAME defines the default page name. - PAGE_PLACE_HOLDER = "{.page}" // PAGE_PLACE_HOLDER defines the place holder for the url template. + DefaultPageName = "page" // DefaultPageName defines the default page name. + DefaultPagePlaceHolder = "{.page}" // DefaultPagePlaceHolder defines the place holder for the url template. ) // New creates and returns a pagination manager. 
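As in gview, the gmode key becomes a descriptive constant (commandEnvKey, still "gf.gmode"), and the new comment states plainly that currentMode is not concurrency-safe, so the mode should be set once during startup. A brief usage sketch, assuming gmode's exported helpers and the usual option/environment lookup (--gf.gmode or GF_GMODE); the exact fallback when neither is set is outside this hunk:

package main

import (
	"fmt"

	"github.com/gogf/gf/util/gmode"
)

func main() {
	// Mode() resolves lazily on first call: the "gf.gmode" command option /
	// GF_GMODE environment variable wins if present.
	fmt.Println("current mode:", gmode.Mode())

	// currentMode is not concurrency-safe, so switch modes only during
	// program initialization.
	gmode.SetDevelop()
	fmt.Println("is develop:", gmode.IsDevelop())
}

The gpage rename above follows the same constant-naming convention (DefaultPageName, DefaultPagePlaceHolder) with unchanged values, so URL templates using "{.page}" keep working.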
@@ -206,7 +206,7 @@ func (p *Page) GetContent(mode int) string { // Note that the UrlTemplate attribute can be either an URL or a URI string with "{.page}" // place holder specifying the page number position. func (p *Page) GetUrl(page int) string { - return gstr.Replace(p.UrlTemplate, PAGE_PLACE_HOLDER, gconv.String(page)) + return gstr.Replace(p.UrlTemplate, DefaultPagePlaceHolder, gconv.String(page)) } // GetLink returns the HTML link tag `a` content for given page number. diff --git a/util/gutil/gutil_comparator.go b/util/gutil/gutil_comparator.go index 39c364868..0dbf4b98d 100644 --- a/util/gutil/gutil_comparator.go +++ b/util/gutil/gutil_comparator.go @@ -77,8 +77,8 @@ func ComparatorUint64(a, b interface{}) int { // ComparatorFloat32 provides a basic comparison on float32. func ComparatorFloat32(a, b interface{}) int { - aFloat := gconv.Float64(a) - bFloat := gconv.Float64(b) + aFloat := gconv.Float32(a) + bFloat := gconv.Float32(b) if aFloat == bFloat { return 0 } diff --git a/util/gutil/gutil_slice.go b/util/gutil/gutil_slice.go index c2a183495..7da40a8db 100644 --- a/util/gutil/gutil_slice.go +++ b/util/gutil/gutil_slice.go @@ -65,3 +65,29 @@ func SliceToMap(slice interface{}) map[string]interface{} { } return nil } + +// SliceToMapWithColumnAsKey converts slice type variable `slice` to `map[interface{}]interface{}` +// The value of specified column use as the key for returned map. +// Eg: +// SliceToMapWithColumnAsKey([{"K1": "v1", "K2": 1}, {"K1": "v2", "K2": 2}], "K1") => {"v1": {"K1": "v1", "K2": 1}, "v2": {"K1": "v2", "K2": 2}} +// SliceToMapWithColumnAsKey([{"K1": "v1", "K2": 1}, {"K1": "v2", "K2": 2}], "K2") => {1: {"K1": "v1", "K2": 1}, 2: {"K1": "v2", "K2": 2}} +func SliceToMapWithColumnAsKey(slice interface{}, key interface{}) map[interface{}]interface{} { + var ( + reflectValue = reflect.ValueOf(slice) + reflectKind = reflectValue.Kind() + ) + for reflectKind == reflect.Ptr { + reflectValue = reflectValue.Elem() + reflectKind = reflectValue.Kind() + } + data := make(map[interface{}]interface{}) + switch reflectKind { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + if k, ok := ItemValue(reflectValue.Index(i), key); ok { + data[k] = reflectValue.Index(i).Interface() + } + } + } + return data +} diff --git a/util/gutil/gutil_z_unit_slice_test.go b/util/gutil/gutil_z_unit_slice_test.go index 0f3f61938..218b1b65b 100755 --- a/util/gutil/gutil_z_unit_slice_test.go +++ b/util/gutil/gutil_z_unit_slice_test.go @@ -35,3 +35,23 @@ func Test_SliceToMap(t *testing.T) { t.Assert(m, nil) }) } + +func Test_SliceToMapWithColumnAsKey(t *testing.T) { + m1 := g.Map{"K1": "v1", "K2": 1} + m2 := g.Map{"K1": "v2", "K2": 2} + s := g.Slice{m1, m2} + gtest.C(t, func(t *gtest.T) { + m := gutil.SliceToMapWithColumnAsKey(s, "K1") + t.Assert(m, g.MapAnyAny{ + "v1": m1, + "v2": m2, + }) + }) + gtest.C(t, func(t *gtest.T) { + m := gutil.SliceToMapWithColumnAsKey(s, "K2") + t.Assert(m, g.MapAnyAny{ + 1: m1, + 2: m2, + }) + }) +} diff --git a/version.go b/version.go index d2895b5ad..bd9e356f1 100644 --- a/version.go +++ b/version.go @@ -1,4 +1,4 @@ package gf -const VERSION = "v1.16.0" +const VERSION = "v1.16.4" const AUTHORS = "john<john@goframe.org>"
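Two small behavioural notes to close: ComparatorFloat32 now converts both operands with gconv.Float32, so the comparison actually happens at float32 precision, and SliceToMapWithColumnAsKey is the one new exported utility, indexing a slice of records by the value of a chosen column. A usage sketch complementing the unit test above; the row data is illustrative:

package main

import (
	"fmt"

	"github.com/gogf/gf/frame/g"
	"github.com/gogf/gf/util/gutil"
)

func main() {
	// Index record-like maps by one of their columns.
	rows := g.Slice{
		g.Map{"Uid": 1, "Name": "john"},
		g.Map{"Uid": 2, "Name": "smith"},
	}

	byUid := gutil.SliceToMapWithColumnAsKey(rows, "Uid")
	fmt.Println(byUid[1]) // map[Name:john Uid:1]

	byName := gutil.SliceToMapWithColumnAsKey(rows, "Name")
	fmt.Println(byName["smith"]) // map[Name:smith Uid:2]
}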