mirror of
https://gitee.com/johng/gf.git
synced 2024-11-30 03:07:45 +08:00
add CtxStrict feature for package gdb
This commit is contained in:
parent
859ea150ed
commit
b958689264
@ -70,8 +70,6 @@ type DB interface {
|
||||
|
||||
// Ctx is a chaining function, which creates and returns a new DB that is a shallow copy
|
||||
// of current DB object and with given context in it.
|
||||
// Note that this returned DB object can be used only once, so do not assign it to
|
||||
// a global or package variable for long using.
|
||||
// Also see Core.Ctx.
|
||||
Ctx(ctx context.Context) DB
|
||||
|
||||
@ -105,23 +103,21 @@ type DB interface {
|
||||
DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete.
|
||||
DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery.
|
||||
DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec.
|
||||
DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) // See Core.DoCommit.
|
||||
DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoCommit.
|
||||
DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare.
|
||||
|
||||
// ===========================================================================
|
||||
// Query APIs for convenience purpose.
|
||||
// ===========================================================================
|
||||
|
||||
GetAll(sql string, args ...interface{}) (Result, error) // See Core.GetAll.
|
||||
GetOne(sql string, args ...interface{}) (Record, error) // See Core.GetOne.
|
||||
GetValue(sql string, args ...interface{}) (Value, error) // See Core.GetValue.
|
||||
GetArray(sql string, args ...interface{}) ([]Value, error) // See Core.GetArray.
|
||||
GetCount(sql string, args ...interface{}) (int, error) // See Core.GetCount.
|
||||
GetStruct(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetStruct.
|
||||
GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error // See Core.GetStructs.
|
||||
GetScan(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetScan.
|
||||
Union(unions ...*Model) *Model // See Core.Union.
|
||||
UnionAll(unions ...*Model) *Model // See Core.UnionAll.
|
||||
GetAll(sql string, args ...interface{}) (Result, error) // See Core.GetAll.
|
||||
GetOne(sql string, args ...interface{}) (Record, error) // See Core.GetOne.
|
||||
GetValue(sql string, args ...interface{}) (Value, error) // See Core.GetValue.
|
||||
GetArray(sql string, args ...interface{}) ([]Value, error) // See Core.GetArray.
|
||||
GetCount(sql string, args ...interface{}) (int, error) // See Core.GetCount.
|
||||
GetScan(objPointer interface{}, sql string, args ...interface{}) error // See Core.GetScan.
|
||||
Union(unions ...*Model) *Model // See Core.Union.
|
||||
UnionAll(unions ...*Model) *Model // See Core.UnionAll.
|
||||
|
||||
// ===========================================================================
|
||||
// Master/Slave specification support.
|
||||
@ -173,7 +169,7 @@ type DB interface {
|
||||
GetChars() (charLeft string, charRight string) // See Core.GetChars.
|
||||
Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables.
|
||||
TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields.
|
||||
FilteredLink() string
|
||||
FilteredLink() string // FilteredLink is used for filtering sensitive information in `Link` configuration before output it to tracing server.
|
||||
}
|
||||
|
||||
// Core is the base struct for database management.
|
||||
@ -270,6 +266,8 @@ const (
|
||||
ctxTimeoutTypeQuery
|
||||
ctxTimeoutTypePrepare
|
||||
commandEnvKeyForDryRun = "gf.gdb.dryrun"
|
||||
ctxStrictKeyName = "gf.gdb.CtxStrictEnabled"
|
||||
ctxStrictErrorStr = "context is required for database operation, did you missing call function Ctx"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -40,6 +40,7 @@ func (c *Core) Ctx(ctx context.Context) DB {
|
||||
if c.ctx != nil {
|
||||
return c.db
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxStrictKeyName, 1)
|
||||
// It makes a shallow copy of current db and changes its context for next chaining operation.
|
||||
var (
|
||||
err error
|
||||
@ -189,9 +190,9 @@ func (c *Core) GetScan(pointer interface{}, sql string, args ...interface{}) err
|
||||
k = t.Elem().Kind()
|
||||
switch k {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return c.db.GetStructs(pointer, sql, args...)
|
||||
return c.db.GetCore().GetStructs(pointer, sql, args...)
|
||||
case reflect.Struct:
|
||||
return c.db.GetStruct(pointer, sql, args...)
|
||||
return c.db.GetCore().GetStruct(pointer, sql, args...)
|
||||
}
|
||||
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
|
||||
}
|
||||
|
@ -49,6 +49,7 @@ type ConfigNode struct {
|
||||
UpdatedAt string `json:"updatedAt"` // (Optional) The filed name of table for automatic-filled updated datetime.
|
||||
DeletedAt string `json:"deletedAt"` // (Optional) The filed name of table for automatic-filled updated datetime.
|
||||
TimeMaintainDisabled bool `json:"timeMaintainDisabled"` // (Optional) Disable the automatic time maintaining feature.
|
||||
CtxStrict bool `json:"ctxStrict"` // (Optional) Strictly require context input for all database operations.
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -10,6 +10,7 @@ package gdb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/gogf/gf/errors/gerror"
|
||||
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
)
|
||||
@ -33,12 +34,17 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
|
||||
link = &txLink{tx.tx}
|
||||
}
|
||||
}
|
||||
// Link execution.
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args = c.db.DoCommit(ctx, link, sql, args)
|
||||
|
||||
if c.GetConfig().QueryTimeout > 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout)
|
||||
}
|
||||
|
||||
// Link execution.
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args, err = c.db.DoCommit(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mTime1 := gtime.TimestampMilli()
|
||||
rows, err = link.QueryContext(ctx, sql, args...)
|
||||
mTime2 := gtime.TimestampMilli()
|
||||
@ -85,15 +91,19 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
|
||||
link = &txLink{tx.tx}
|
||||
}
|
||||
}
|
||||
// Link execution.
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args = c.db.DoCommit(ctx, link, sql, args)
|
||||
|
||||
if c.GetConfig().ExecTimeout > 0 {
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().ExecTimeout)
|
||||
defer cancelFunc()
|
||||
}
|
||||
|
||||
// Link execution.
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args, err = c.db.DoCommit(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mTime1 := gtime.TimestampMilli()
|
||||
if !c.db.GetDryRun() {
|
||||
result, err = link.ExecContext(ctx, sql, args...)
|
||||
@ -120,6 +130,18 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
|
||||
return result, formatError(err, sql, args...)
|
||||
}
|
||||
|
||||
// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver.
|
||||
// The parameter `link` specifies the current database connection operation object. You can modify the sql
|
||||
// string `sql` and its arguments `args` as you wish before they're committed to driver.
|
||||
func (c *Core) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
if c.db.GetConfig().CtxStrict {
|
||||
if v := ctx.Value(ctxStrictKeyName); v == nil {
|
||||
return sql, args, gerror.New(ctxStrictErrorStr)
|
||||
}
|
||||
}
|
||||
return sql, args, nil
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
// Multiple queries or executions may be run concurrently from the
|
||||
// returned statement.
|
||||
@ -156,6 +178,13 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, err
|
||||
// DO NOT USE cancel function in prepare statement.
|
||||
ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
|
||||
}
|
||||
|
||||
if c.db.GetConfig().CtxStrict {
|
||||
if v := ctx.Value(ctxStrictKeyName); v == nil {
|
||||
return nil, gerror.New(ctxStrictErrorStr)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
mTime1 = gtime.TimestampMilli()
|
||||
stmt, err = link.PrepareContext(ctx, sql)
|
||||
|
@ -63,14 +63,6 @@ func (c *Core) GetChars() (charLeft string, charRight string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver.
|
||||
// The parameter `link` specifies the current database connection operation object. You can modify the sql
|
||||
// string `sql` and its arguments `args` as you wish before they're committed to driver.
|
||||
// Also see Core.DoCommit.
|
||||
func (c *Core) DoCommit(sql string) string {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
//
|
||||
|
@ -79,7 +79,10 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// DoCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
defer func() {
|
||||
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
|
||||
}()
|
||||
var index int
|
||||
// Convert place holder char '?' to string "@px".
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
|
||||
@ -87,7 +90,7 @@ func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args
|
||||
return fmt.Sprintf("@p%d", index)
|
||||
})
|
||||
str, _ = gregex.ReplaceString("\"", "", str)
|
||||
return d.parseSql(str), args
|
||||
return d.parseSql(str), args, nil
|
||||
}
|
||||
|
||||
// parseSql does some replacement of the sql before commits it to underlying driver,
|
||||
|
@ -81,8 +81,8 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// DoCommit handles the sql before posts it to database.
|
||||
func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
return sql, args
|
||||
func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
return d.Core.DoCommit(ctx, link, sql, args)
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
|
@ -32,11 +32,6 @@ type DriverOracle struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
const (
|
||||
tableAlias1 = "GFORM1"
|
||||
tableAlias2 = "GFORM2"
|
||||
)
|
||||
|
||||
// New creates and returns a database object for oracle.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
@ -85,7 +80,11 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// DoCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) {
|
||||
func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
defer func() {
|
||||
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
|
||||
}()
|
||||
|
||||
var index int
|
||||
// Convert place holder char '?' to string ":vx".
|
||||
newSql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
|
||||
|
@ -80,15 +80,19 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// DoCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
defer func() {
|
||||
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
|
||||
}()
|
||||
|
||||
var index int
|
||||
// Convert place holder char '?' to string "$x".
|
||||
sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
|
||||
index++
|
||||
return fmt.Sprintf("$%d", index)
|
||||
})
|
||||
sql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $2 OFFSET $1`, sql)
|
||||
return sql, args
|
||||
newSql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $2 OFFSET $1`, sql)
|
||||
return newSql, args, nil
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
|
@ -67,10 +67,8 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// DoCommit deals with the sql string before commits it to underlying sql driver.
|
||||
// TODO 需要增加对Save方法的支持,可使用正则来实现替换,
|
||||
// TODO 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
|
||||
func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
return sql, args
|
||||
func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
return d.Core.DoCommit(ctx, link, sql, args)
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
|
@ -37,7 +37,7 @@ const (
|
||||
)
|
||||
|
||||
// doStmtCommit commits statement according to given `stmtType`.
|
||||
func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interface{}) (result interface{}, err error) {
|
||||
func (s *Stmt) doStmtCommit(ctx context.Context, stmtType string, args ...interface{}) (result interface{}, err error) {
|
||||
var (
|
||||
cancelFuncForTimeout context.CancelFunc
|
||||
timestampMilli1 = gtime.TimestampMilli()
|
||||
@ -86,7 +86,7 @@ func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interf
|
||||
// ExecContext executes a prepared statement with the given arguments and
|
||||
// returns a Result summarizing the effect of the statement.
|
||||
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||
result, err := s.doStmtCommit(stmtTypeExecContext, ctx, args...)
|
||||
result, err := s.doStmtCommit(ctx, stmtTypeExecContext, args...)
|
||||
if result != nil {
|
||||
return result.(sql.Result), err
|
||||
}
|
||||
@ -96,7 +96,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
|
||||
// QueryContext executes a prepared query statement with the given arguments
|
||||
// and returns the query results as a *Rows.
|
||||
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||
result, err := s.doStmtCommit(stmtTypeQueryContext, ctx, args...)
|
||||
result, err := s.doStmtCommit(ctx, stmtTypeQueryContext, args...)
|
||||
if result != nil {
|
||||
return result.(*sql.Rows), err
|
||||
}
|
||||
@ -110,7 +110,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards
|
||||
// the rest.
|
||||
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
|
||||
result, _ := s.doStmtCommit(stmtTypeQueryRowContext, ctx, args...)
|
||||
result, _ := s.doStmtCommit(ctx, stmtTypeQueryRowContext, args...)
|
||||
if result != nil {
|
||||
return result.(*sql.Row)
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
|
||||
|
||||
// DoCommit handles the sql before posts it to database.
|
||||
// It here overwrites the same method of gdb.DriverMysql and makes some custom changes.
|
||||
func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
|
||||
latestSqlString.Set(sql)
|
||||
return d.DriverMysql.DoCommit(ctx, link, sql, args)
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/container/garray"
|
||||
"github.com/gogf/gf/frame/g"
|
||||
@ -29,9 +30,10 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
db gdb.DB
|
||||
dbPrefix gdb.DB
|
||||
configNode gdb.ConfigNode
|
||||
db gdb.DB
|
||||
dbPrefix gdb.DB
|
||||
dbCtxStrict gdb.DB
|
||||
configNode gdb.ConfigNode
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -56,9 +58,15 @@ func init() {
|
||||
}
|
||||
nodePrefix := configNode
|
||||
nodePrefix.Prefix = TableNamePrefix1
|
||||
|
||||
nodeCtxStrict := configNode
|
||||
nodeCtxStrict.CtxStrict = true
|
||||
|
||||
gdb.AddConfigNode("test", configNode)
|
||||
gdb.AddConfigNode("prefix", nodePrefix)
|
||||
gdb.AddConfigNode("ctxstrict", nodeCtxStrict)
|
||||
gdb.AddConfigNode(gdb.DefaultGroupName, configNode)
|
||||
|
||||
// Default db.
|
||||
if r, err := gdb.New(); err != nil {
|
||||
gtest.Error(err)
|
||||
@ -87,6 +95,20 @@ func init() {
|
||||
gtest.Error(err)
|
||||
}
|
||||
dbPrefix.SetSchema(TestSchema1)
|
||||
|
||||
// CtxStrict db.
|
||||
if r, err := gdb.New("ctxstrict"); err != nil {
|
||||
gtest.Error(err)
|
||||
} else {
|
||||
dbCtxStrict = r
|
||||
}
|
||||
if _, err := dbCtxStrict.Ctx(context.TODO()).Exec(fmt.Sprintf(schemaTemplate, TestSchema1)); err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
if _, err := dbCtxStrict.Ctx(context.TODO()).Exec(fmt.Sprintf(schemaTemplate, TestSchema2)); err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
dbCtxStrict.SetSchema(TestSchema1)
|
||||
}
|
||||
|
||||
func createTable(table ...string) string {
|
||||
@ -111,7 +133,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
|
||||
switch configNode.Type {
|
||||
case "sqlite":
|
||||
if _, err := db.Exec(fmt.Sprintf(`
|
||||
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id bigint unsigned NOT NULL AUTO_INCREMENT,
|
||||
passport varchar(45),
|
||||
@ -124,7 +146,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
case "pgsql":
|
||||
if _, err := db.Exec(fmt.Sprintf(`
|
||||
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id bigint NOT NULL,
|
||||
passport varchar(45),
|
||||
@ -137,7 +159,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
case "mssql":
|
||||
if _, err := db.Exec(fmt.Sprintf(`
|
||||
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
|
||||
IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='%s' and xtype='U')
|
||||
CREATE TABLE %s (
|
||||
ID numeric(10,0) NOT NULL,
|
||||
@ -151,7 +173,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
case "oracle":
|
||||
if _, err := db.Exec(fmt.Sprintf(`
|
||||
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
ID NUMBER(10) NOT NULL,
|
||||
PASSPORT VARCHAR(45) NOT NULL,
|
||||
@ -164,7 +186,7 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
case "mysql":
|
||||
if _, err := db.Exec(fmt.Sprintf(`
|
||||
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id int(10) unsigned NOT NULL AUTO_INCREMENT,
|
||||
passport varchar(45) NULL,
|
||||
@ -195,7 +217,7 @@ func createInitTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
})
|
||||
}
|
||||
|
||||
result, err := db.Insert(name, array.Slice())
|
||||
result, err := db.Ctx(context.TODO()).Insert(name, array.Slice())
|
||||
gtest.AssertNil(err)
|
||||
|
||||
n, e := result.RowsAffected()
|
||||
@ -205,7 +227,7 @@ func createInitTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
}
|
||||
|
||||
func dropTableWithDb(db gdb.DB, table string) {
|
||||
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
|
||||
if _, err := db.Ctx(context.TODO()).Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
}
|
||||
|
@ -62,3 +62,23 @@ func Test_Ctx_Model(t *testing.T) {
|
||||
db.Model(table).All()
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Ctx_Strict(t *testing.T) {
|
||||
table := createInitTableWithDb(dbCtxStrict)
|
||||
defer dropTable(table)
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
_, err := dbCtxStrict.Query("select 1")
|
||||
t.AssertNE(err, nil)
|
||||
})
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
r, err := dbCtxStrict.Model(table).All()
|
||||
t.AssertNE(err, nil)
|
||||
t.Assert(len(r), 0)
|
||||
})
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
r, err := dbCtxStrict.Model(table).Ctx(context.TODO()).All()
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(r), TableSize)
|
||||
})
|
||||
}
|
||||
|
@ -584,7 +584,7 @@ func Test_DB_GetStruct(t *testing.T) {
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
err := db.GetStruct(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
|
||||
err := db.GetScan(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
|
||||
t.AssertNil(err)
|
||||
t.Assert(user.NickName, "name_3")
|
||||
})
|
||||
@ -597,7 +597,7 @@ func Test_DB_GetStruct(t *testing.T) {
|
||||
CreateTime *gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
err := db.GetStruct(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
|
||||
err := db.GetScan(user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3)
|
||||
t.AssertNil(err)
|
||||
t.Assert(user.NickName, "name_3")
|
||||
})
|
||||
@ -615,7 +615,7 @@ func Test_DB_GetStructs(t *testing.T) {
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
var users []User
|
||||
err := db.GetStructs(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
|
||||
err := db.GetScan(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(users), TableSize-1)
|
||||
t.Assert(users[0].Id, 2)
|
||||
@ -635,7 +635,7 @@ func Test_DB_GetStructs(t *testing.T) {
|
||||
CreateTime *gtime.Time
|
||||
}
|
||||
var users []User
|
||||
err := db.GetStructs(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
|
||||
err := db.GetScan(&users, fmt.Sprintf("SELECT * FROM %s WHERE id>?", table), 1)
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(users), TableSize-1)
|
||||
t.Assert(users[0].Id, 2)
|
||||
|
Loading…
Reference in New Issue
Block a user