add CtxStrict feature for package gdb

This commit is contained in:
John Guo 2021-06-26 18:20:55 +08:00
parent 859ea150ed
commit b958689264
15 changed files with 133 additions and 66 deletions

View File

@ -70,8 +70,6 @@ type DB interface {
// Ctx is a chaining function, which creates and returns a new DB that is a shallow copy
// of current DB object and with given context in it.
// Note that this returned DB object can be used only once, so do not assign it to
// a global or package variable for long using.
// Also see Core.Ctx.
Ctx(ctx context.Context) DB
@ -105,23 +103,21 @@ type DB interface {
DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete.
DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery.
DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec.
DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) // See Core.DoCommit.
DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoCommit.
DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare.
// ===========================================================================
// Query APIs for convenience purpose.
// ===========================================================================
GetAll(sql string, args ...interface{}) (Result, error) // See Core.GetAll.
GetOne(sql string, args ...interface{}) (Record, error) // See Core.GetOne.
GetValue(sql string, args ...interface{}) (Value, error) // See Core.GetValue.
GetArray(sql string, args ...interface{}) ([]Value, error) // See Core.GetArray.
GetCount(sql string, args ...interface{}) (int, error) // See Core.GetCount.
GetStruct(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetStruct.
GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error // See Core.GetStructs.
GetScan(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetScan.
Union(unions ...*Model) *Model // See Core.Union.
UnionAll(unions ...*Model) *Model // See Core.UnionAll.
GetAll(sql string, args ...interface{}) (Result, error) // See Core.GetAll.
GetOne(sql string, args ...interface{}) (Record, error) // See Core.GetOne.
GetValue(sql string, args ...interface{}) (Value, error) // See Core.GetValue.
GetArray(sql string, args ...interface{}) ([]Value, error) // See Core.GetArray.
GetCount(sql string, args ...interface{}) (int, error) // See Core.GetCount.
GetScan(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetScan.
Union(unions ...*Model) *Model // See Core.Union.
UnionAll(unions ...*Model) *Model // See Core.UnionAll.
// ===========================================================================
// Master/Slave specification support.
@ -173,7 +169,7 @@ type DB interface {
GetChars() (charLeft string, charRight string) // See Core.GetChars.
Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables.
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields.
FilteredLink() string
FilteredLink() string // FilteredLink is used for filtering sensitive information in `Link` configuration before output it to tracing server.
}
// Core is the base struct for database management.
@ -270,6 +266,8 @@ const (
ctxTimeoutTypeQuery
ctxTimeoutTypePrepare
commandEnvKeyForDryRun = "gf.gdb.dryrun"
ctxStrictKeyName = "gf.gdb.CtxStrictEnabled"
ctxStrictErrorStr = "context is required for database operation, did you missing call function Ctx"
)
var (

View File

@ -40,6 +40,7 @@ func (c *Core) Ctx(ctx context.Context) DB {
if c.ctx != nil {
return c.db
}
ctx = context.WithValue(ctx, ctxStrictKeyName, 1)
// It makes a shallow copy of current db and changes its context for next chaining operation.
var (
err error
@ -189,9 +190,9 @@ func (c *Core) GetScan(pointer interface{}, sql string, args ...interface{}) err
k = t.Elem().Kind()
switch k {
case reflect.Array, reflect.Slice:
return c.db.GetStructs(pointer, sql, args...)
return c.db.GetCore().GetStructs(pointer, sql, args...)
case reflect.Struct:
return c.db.GetStruct(pointer, sql, args...)
return c.db.GetCore().GetStruct(pointer, sql, args...)
}
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
}

View File

@ -49,6 +49,7 @@ type ConfigNode struct {
UpdatedAt string `json:"updatedAt"` // (Optional) The filed name of table for automatic-filled updated datetime.
DeletedAt string `json:"deletedAt"` // (Optional) The filed name of table for automatic-filled updated datetime.
TimeMaintainDisabled bool `json:"timeMaintainDisabled"` // (Optional) Disable the automatic time maintaining feature.
CtxStrict bool `json:"ctxStrict"` // (Optional) Strictly require context input for all database operations.
}
const (

View File

@ -10,6 +10,7 @@ package gdb
import (
"context"
"database/sql"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
)
@ -33,12 +34,17 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
link = &txLink{tx.tx}
}
}
// Link execution.
sql, args = formatSql(sql, args)
sql, args = c.db.DoCommit(ctx, link, sql, args)
if c.GetConfig().QueryTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout)
}
// Link execution.
sql, args = formatSql(sql, args)
sql, args, err = c.db.DoCommit(ctx, link, sql, args)
if err != nil {
return nil, err
}
mTime1 := gtime.TimestampMilli()
rows, err = link.QueryContext(ctx, sql, args...)
mTime2 := gtime.TimestampMilli()
@ -85,15 +91,19 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
link = &txLink{tx.tx}
}
}
// Link execution.
sql, args = formatSql(sql, args)
sql, args = c.db.DoCommit(ctx, link, sql, args)
if c.GetConfig().ExecTimeout > 0 {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().ExecTimeout)
defer cancelFunc()
}
// Link execution.
sql, args = formatSql(sql, args)
sql, args, err = c.db.DoCommit(ctx, link, sql, args)
if err != nil {
return nil, err
}
mTime1 := gtime.TimestampMilli()
if !c.db.GetDryRun() {
result, err = link.ExecContext(ctx, sql, args...)
@ -120,6 +130,18 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
return result, formatError(err, sql, args...)
}
// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver.
// The parameter `link` specifies the current database connection operation object. You can modify the sql
// string `sql` and its arguments `args` as you wish before they're committed to driver.
func (c *Core) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
if c.db.GetConfig().CtxStrict {
if v := ctx.Value(ctxStrictKeyName); v == nil {
return sql, args, gerror.New(ctxStrictErrorStr)
}
}
return sql, args, nil
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
@ -156,6 +178,13 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, err
// DO NOT USE cancel function in prepare statement.
ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
}
if c.db.GetConfig().CtxStrict {
if v := ctx.Value(ctxStrictKeyName); v == nil {
return nil, gerror.New(ctxStrictErrorStr)
}
}
var (
mTime1 = gtime.TimestampMilli()
stmt, err = link.PrepareContext(ctx, sql)

View File

@ -63,14 +63,6 @@ func (c *Core) GetChars() (charLeft string, charRight string) {
return "", ""
}
// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver.
// The parameter `link` specifies the current database connection operation object. You can modify the sql
// string `sql` and its arguments `args` as you wish before they're committed to driver.
// Also see Core.DoCommit.
func (c *Core) DoCommit(sql string) string {
return sql
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
//

View File

@ -79,7 +79,10 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
defer func() {
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
}()
var index int
// Convert place holder char '?' to string "@px".
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
@ -87,7 +90,7 @@ func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args
return fmt.Sprintf("@p%d", index)
})
str, _ = gregex.ReplaceString("\"", "", str)
return d.parseSql(str), args
return d.parseSql(str), args, nil
}
// parseSql does some replacement of the sql before commits it to underlying driver,

View File

@ -81,8 +81,8 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) {
}
// DoCommit handles the sql before posts it to database.
func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
return sql, args
func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return d.Core.DoCommit(ctx, link, sql, args)
}
// Tables retrieves and returns the tables of current schema.

View File

@ -32,11 +32,6 @@ type DriverOracle struct {
*Core
}
const (
tableAlias1 = "GFORM1"
tableAlias2 = "GFORM2"
)
// New creates and returns a database object for oracle.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) {
@ -85,7 +80,11 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) {
func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
defer func() {
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
}()
var index int
// Convert place holder char '?' to string ":vx".
newSql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {

View File

@ -80,15 +80,19 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) {
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
defer func() {
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
}()
var index int
// Convert place holder char '?' to string "$x".
sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
index++
return fmt.Sprintf("$%d", index)
})
sql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $2 OFFSET $1`, sql)
return sql, args
newSql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $2 OFFSET $1`, sql)
return newSql, args, nil
}
// Tables retrieves and returns the tables of current schema.

View File

@ -67,10 +67,8 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) {
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
// TODO 需要增加对Save方法的支持可使用正则来实现替换
// TODO 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
return sql, args
func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return d.Core.DoCommit(ctx, link, sql, args)
}
// Tables retrieves and returns the tables of current schema.

View File

@ -37,7 +37,7 @@ const (
)
// doStmtCommit commits statement according to given `stmtType`.
func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interface{}) (result interface{}, err error) {
func (s *Stmt) doStmtCommit(ctx context.Context, stmtType string, args ...interface{}) (result interface{}, err error) {
var (
cancelFuncForTimeout context.CancelFunc
timestampMilli1 = gtime.TimestampMilli()
@ -86,7 +86,7 @@ func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interf
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
result, err := s.doStmtCommit(stmtTypeExecContext, ctx, args...)
result, err := s.doStmtCommit(ctx, stmtTypeExecContext, args...)
if result != nil {
return result.(sql.Result), err
}
@ -96,7 +96,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
result, err := s.doStmtCommit(stmtTypeQueryContext, ctx, args...)
result, err := s.doStmtCommit(ctx, stmtTypeQueryContext, args...)
if result != nil {
return result.(*sql.Rows), err
}
@ -110,7 +110,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
result, _ := s.doStmtCommit(stmtTypeQueryRowContext, ctx, args...)
result, _ := s.doStmtCommit(ctx, stmtTypeQueryRowContext, args...)
if result != nil {
return result.(*sql.Row)
}

View File

@ -43,7 +43,7 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
// DoCommit handles the sql before posts it to database.
// It here overwrites the same method of gdb.DriverMysql and makes some custom changes.
func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (string, []interface{}) {
func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
latestSqlString.Set(sql)
return d.DriverMysql.DoCommit(ctx, link, sql, args)
}

View File

@ -7,6 +7,7 @@
package gdb_test
import (
"context"
"fmt"
"github.com/gogf/gf/container/garray"
"github.com/gogf/gf/frame/g"
@ -29,9 +30,10 @@ const (
)
var (
db gdb.DB
dbPrefix gdb.DB
configNode gdb.ConfigNode
db gdb.DB
dbPrefix gdb.DB
dbCtxStrict gdb.DB
configNode gdb.ConfigNode
)
func init() {
@ -56,9 +58,15 @@ func init() {
}
nodePrefix := configNode
nodePrefix.Prefix = TableNamePrefix1
nodeCtxStrict := configNode
nodeCtxStrict.CtxStrict = true
gdb.AddConfigNode("test", configNode)
gdb.AddConfigNode("prefix", nodePrefix)
gdb.AddConfigNode("ctxstrict", nodeCtxStrict)
gdb.AddConfigNode(gdb.DefaultGroupName, configNode)
// Default db.
if r, err := gdb.New(); err != nil {
gtest.Error(err)
@ -87,6 +95,20 @@ func init() {
gtest.Error(err)
}
dbPrefix.SetSchema(TestSchema1)
// CtxStrict db.
if r, err := gdb.New("ctxstrict"); err != nil {
gtest.Error(err)
} else {
dbCtxStrict = r
}
if _, err := dbCtxStrict.Ctx(context.TODO()).Exec(fmt.Sprintf(schemaTemplate, TestSchema1)); err != nil {
gtest.Error(err)
}
if _, err := dbCtxStrict.Ctx(context.TODO()).Exec(fmt.Sprintf(schemaTemplate, TestSchema2)); err != nil {
gtest.Error(err)
}
dbCtxStrict.SetSchema(TestSchema1)
}
func createTable(table ...string) string {
@ -111,7 +133,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
switch configNode.Type {
case "sqlite":
if _, err := db.Exec(fmt.Sprintf(`
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
CREATE TABLE %s (
id bigint unsigned NOT NULL AUTO_INCREMENT,
passport varchar(45),
@ -124,7 +146,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
gtest.Fatal(err)
}
case "pgsql":
if _, err := db.Exec(fmt.Sprintf(`
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
CREATE TABLE %s (
id bigint NOT NULL,
passport varchar(45),
@ -137,7 +159,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
gtest.Fatal(err)
}
case "mssql":
if _, err := db.Exec(fmt.Sprintf(`
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='%s' and xtype='U')
CREATE TABLE %s (
ID numeric(10,0) NOT NULL,
@ -151,7 +173,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
gtest.Fatal(err)
}
case "oracle":
if _, err := db.Exec(fmt.Sprintf(`
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
CREATE TABLE %s (
ID NUMBER(10) NOT NULL,
PASSPORT VARCHAR(45) NOT NULL,
@ -164,7 +186,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
gtest.Fatal(err)
}
case "mysql":
if _, err := db.Exec(fmt.Sprintf(`
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
CREATE TABLE %s (
id int(10) unsigned NOT NULL AUTO_INCREMENT,
passport varchar(45) NULL,
@ -195,7 +217,7 @@ func createInitTableWithDb(db gdb.DB, table ...string) (name string) {
})
}
result, err := db.Insert(name, array.Slice())
result, err := db.Ctx(context.TODO()).Insert(name, array.Slice())
gtest.AssertNil(err)
n, e := result.RowsAffected()
@ -205,7 +227,7 @@ func createInitTableWithDb(db gdb.DB, table ...string) (name string) {
}
func dropTableWithDb(db gdb.DB, table string) {
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
gtest.Error(err)
}
}

View File

@ -62,3 +62,23 @@ func Test_Ctx_Model(t *testing.T) {
db.Model(table).All()
})
}
func Test_Ctx_Strict(t *testing.T) {
table := createInitTableWithDb(dbCtxStrict)
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
_, err := dbCtxStrict.Query("select 1")
t.AssertNE(err, nil)
})
gtest.C(t, func(t *gtest.T) {
r, err := dbCtxStrict.Model(table).All()
t.AssertNE(err, nil)
t.Assert(len(r), 0)
})
gtest.C(t, func(t *gtest.T) {
r, err := dbCtxStrict.Model(table).Ctx(context.TODO()).All()
t.AssertNil(err)
t.Assert(len(r), TableSize)
})
}

View File

@ -584,7 +584,7 @@ func Test_DB_GetStruct(t *testing.T) {
CreateTime gtime.Time
}
user := new(User)
err := db.GetStruct(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
err := db.GetScan(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
t.AssertNil(err)
t.Assert(user.NickName, "name_3")
})
@ -597,7 +597,7 @@ func Test_DB_GetStruct(t *testing.T) {
CreateTime *gtime.Time
}
user := new(User)
err := db.GetStruct(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
err := db.GetScan(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
t.AssertNil(err)
t.Assert(user.NickName, "name_3")
})
@ -615,7 +615,7 @@ func Test_DB_GetStructs(t *testing.T) {
CreateTime gtime.Time
}
var users []User
err := db.GetStructs(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
err := db.GetScan(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
t.AssertNil(err)
t.Assert(len(users), TableSize-1)
t.Assert(users[0].Id, 2)
@ -635,7 +635,7 @@ func Test_DB_GetStructs(t *testing.T) {
CreateTime *gtime.Time
}
var users []User
err := db.GetStructs(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
err := db.GetScan(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
t.AssertNil(err)
t.Assert(len(users), TableSize-1)
t.Assert(users[0].Id, 2)