add gdb.DB.DoFilter; improve function gdb.DB.DoCommit for package gdb

This commit is contained in:
John Guo 2021-12-27 20:51:26 +08:00
parent 14e96069f2
commit b00de2c617
10 changed files with 177 additions and 149 deletions

View File

@ -100,7 +100,8 @@ type DB interface {
DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete.
DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) // See Core.DoQuery.
DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec.
DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoCommit.
DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoFilter.
DoCommit(ctx context.Context, in DoCommitInput) (out *DoCommitOutput, err error) // See Core.DoCommit.
DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare.
// ===========================================================================
@ -183,6 +184,22 @@ type Core struct {
config *ConfigNode // Current config node.
}
// DoCommitInput is the input parameters for function DoCommit.
type DoCommitInput struct {
Stmt *sql.Stmt
Link Link
Sql string
Args []interface{}
Type string
}
// DoCommitOutput is the output parameters for function DoCommit.
type DoCommitOutput struct {
Row *sql.Row // Row is the result of Stmt.QueryRowContext.
Rows *sql.Rows // Rows is the result of query statement.
Result sql.Result // Result is the result of exec statement.
}
// Driver is the interface for integrating sql drivers into package gdb.
type Driver interface {
// New creates and returns a database object for specified database server.
@ -278,6 +295,14 @@ const (
dbRoleSlave = `slave`
)
const (
DoCommitTypeExecContext = "ExecContext"
DoCommitTypeQueryContext = "QueryContext"
DoCommitTypeStmtExecContext = "Statement.ExecContext"
DoCommitTypeStmtQueryContext = "Statement.QueryContext"
DoCommitTypeStmtQueryRowContext = "Statement.QueryRowContext"
)
var (
// instances is the management map for instances.
instances = gmap.NewStrAnyMap(true)

View File

@ -12,6 +12,8 @@ import (
"database/sql"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/internal/intlog"
"github.com/gogf/gf/v2/os/gtime"
)
@ -45,40 +47,27 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout)
}
// Link execution.
// Sql filtering.
sql, args = formatSql(sql, args)
sql, args, err = c.db.DoCommit(ctx, link, sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
}
mTime1 := gtime.TimestampMilli()
rows, err := link.QueryContext(ctx, sql, args...)
mTime2 := gtime.TimestampMilli()
if err == nil {
result, err = c.convertRowsToResult(ctx, rows)
// Link execution.
var out *DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Link: link,
Sql: sql,
Args: args,
Stmt: nil,
Type: DoCommitTypeQueryContext,
})
if err != nil {
return nil, err
}
sqlObj := &Sql{
Sql: sql,
Type: sqlTypeQueryContext,
Args: args,
Format: FormatSqlWithArgs(sql, args),
Error: err,
Start: mTime1,
End: mTime2,
Group: c.db.GetGroup(),
IsTransaction: link.IsTransaction(),
RowsAffected: int64(result.Len()),
}
// Tracing and logging.
c.addSqlToTracing(ctx, sqlObj)
if c.db.GetDebug() {
c.writeSqlToLogger(ctx, sqlObj)
}
if err == nil {
return result, nil
} else {
err = formatError(err, sql, args...)
if out != nil {
result, err = c.RowsToResult(ctx, out.Rows)
return result, err
}
return nil, err
}
@ -114,49 +103,96 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
defer cancelFunc()
}
// Link execution.
// Sql filtering.
sql, args = formatSql(sql, args)
sql, args, err = c.db.DoCommit(ctx, link, sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
}
// Link execution.
var out *DoCommitOutput
out, err = c.db.DoCommit(ctx, DoCommitInput{
Link: link,
Sql: sql,
Args: args,
Stmt: nil,
Type: DoCommitTypeExecContext,
})
if out != nil {
return out.Result, err
}
return nil, err
}
mTime1 := gtime.TimestampMilli()
if !c.db.GetDryRun() {
result, err = link.ExecContext(ctx, sql, args...)
} else {
result = new(SqlResult)
}
mTime2 := gtime.TimestampMilli()
var rowsAffected int64
if err == nil {
rowsAffected, err = result.RowsAffected()
}
sqlObj := &Sql{
Sql: sql,
Type: sqlTypeExecContext,
Args: args,
Format: FormatSqlWithArgs(sql, args),
Error: err,
Start: mTime1,
End: mTime2,
Group: c.db.GetGroup(),
IsTransaction: link.IsTransaction(),
RowsAffected: rowsAffected,
// DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver.
// The parameter `link` specifies the current database connection operation object. You can modify the sql
// string `sql` and its arguments `args` as you wish before they're committed to driver.
func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return sql, args, nil
}
// DoCommit commits current sql and arguments to underlying sql driver.
func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (*DoCommitOutput, error) {
var (
err error
cancelFuncForTimeout context.CancelFunc
out = &DoCommitOutput{}
timestampMilli1 = gtime.TimestampMilli()
)
switch in.Type {
case DoCommitTypeExecContext:
if c.db.GetDryRun() {
out.Result = new(SqlResult)
} else {
out.Result, err = in.Link.ExecContext(ctx, in.Sql, in.Args...)
}
case DoCommitTypeQueryContext:
out.Rows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...)
case DoCommitTypeStmtExecContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeExec, ctx)
defer cancelFuncForTimeout()
if c.db.GetDryRun() {
out.Result = new(SqlResult)
} else {
out.Result, err = in.Stmt.ExecContext(ctx, in.Args...)
}
case DoCommitTypeStmtQueryContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
defer cancelFuncForTimeout()
out.Rows, err = in.Stmt.QueryContext(ctx, in.Args...)
case DoCommitTypeStmtQueryRowContext:
ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
defer cancelFuncForTimeout()
out.Row = in.Stmt.QueryRowContext(ctx, in.Args...)
default:
panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid DoCommitType "%s"`, in.Type))
}
var (
timestampMilli2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: in.Sql,
Type: in.Type,
Args: in.Args,
Format: FormatSqlWithArgs(in.Sql, in.Args),
Error: err,
Start: timestampMilli1,
End: timestampMilli2,
Group: c.db.GetGroup(),
IsTransaction: in.Link.IsTransaction(),
}
)
// Tracing and logging.
c.addSqlToTracing(ctx, sqlObj)
if c.db.GetDebug() {
c.writeSqlToLogger(ctx, sqlObj)
}
return result, formatError(err, sql, args...)
}
// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver.
// The parameter `link` specifies the current database connection operation object. You can modify the sql
// string `sql` and its arguments `args` as you wish before they're committed to driver.
func (c *Core) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return sql, args, nil
return out, formatError(err, in.Sql, in.Args...)
}
// Prepare creates a prepared statement for later queries or executions.
@ -239,8 +275,8 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, err
}, err
}
// convertRowsToResult converts underlying data record type sql.Rows to Result type.
func (c *Core) convertRowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
// RowsToResult converts underlying data record type sql.Rows to Result type.
func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
if rows == nil {
return nil, nil
}

View File

@ -83,10 +83,10 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
return "\"", "\""
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
// DoFilter deals with the sql string before commits it to underlying sql driver.
func (d *DriverMssql) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
defer func() {
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
newSql, newArgs, err = d.Core.DoFilter(ctx, link, newSql, newArgs)
}()
var index int
// Convert placeholder char '?' to string "@px".

View File

@ -87,9 +87,9 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) {
return "`", "`"
}
// DoCommit handles the sql before posts it to database.
func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return d.Core.DoCommit(ctx, link, sql, args)
// DoFilter handles the sql before posts it to database.
func (d *DriverMysql) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return d.Core.DoFilter(ctx, link, sql, args)
}
// Tables retrieves and returns the tables of current schema.

View File

@ -86,10 +86,10 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
return "\"", "\""
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
// DoFilter deals with the sql string before commits it to underlying sql driver.
func (d *DriverOracle) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
defer func() {
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
newSql, newArgs, err = d.Core.DoFilter(ctx, link, newSql, newArgs)
}()
var index int

View File

@ -37,7 +37,7 @@ func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) {
}, nil
}
// Open creates and returns a underlying sql.DB object for pgsql.
// Open creates and returns an underlying sql.DB object for pgsql.
func (d *DriverPgsql) Open(config *ConfigNode) (db *sql.DB, err error) {
var (
source string
@ -85,10 +85,10 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) {
return "\"", "\""
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
// DoFilter deals with the sql string before commits it to underlying sql driver.
func (d *DriverPgsql) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
defer func() {
newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
newSql, newArgs, err = d.Core.DoFilter(ctx, link, newSql, newArgs)
}()
var index int

View File

@ -73,9 +73,9 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) {
return "`", "`"
}
// DoCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return d.Core.DoCommit(ctx, link, sql, args)
// DoFilter deals with the sql string before commits it to underlying sql driver.
func (d *DriverSqlite) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
return d.Core.DoFilter(ctx, link, sql, args)
}
// Tables retrieves and returns the tables of current schema.

View File

@ -44,6 +44,9 @@ func (r *SqlResult) MustGetInsertId() int64 {
// driver may support this.
// Also, See sql.Result.
func (r *SqlResult) RowsAffected() (int64, error) {
if r.result == nil {
return 0, nil
}
if r.affected > 0 {
return r.affected, nil
}

View File

@ -9,10 +9,6 @@ package gdb
import (
"context"
"database/sql"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gtime"
)
// Stmt is a prepared statement.
@ -31,65 +27,18 @@ type Stmt struct {
sql string
}
const (
stmtTypeExecContext = "Statement.ExecContext"
stmtTypeQueryContext = "Statement.QueryContext"
stmtTypeQueryRowContext = "Statement.QueryRowContext"
)
// doStmtCommit commits statement according to given `stmtType`.
func (s *Stmt) doStmtCommit(ctx context.Context, stmtType string, args ...interface{}) (result interface{}, err error) {
var (
cancelFuncForTimeout context.CancelFunc
timestampMilli1 = gtime.TimestampMilli()
)
switch stmtType {
case stmtTypeExecContext:
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeExec, ctx)
defer cancelFuncForTimeout()
result, err = s.Stmt.ExecContext(ctx, args...)
case stmtTypeQueryContext:
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
defer cancelFuncForTimeout()
result, err = s.Stmt.QueryContext(ctx, args...)
case stmtTypeQueryRowContext:
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
defer cancelFuncForTimeout()
result = s.Stmt.QueryRowContext(ctx, args...)
default:
panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid stmtType: %s`, stmtType))
}
var (
timestampMilli2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: s.sql,
Type: stmtType,
Args: args,
Format: FormatSqlWithArgs(s.sql, args),
Error: err,
Start: timestampMilli1,
End: timestampMilli2,
Group: s.core.db.GetGroup(),
IsTransaction: s.link.IsTransaction(),
}
)
// Tracing and logging.
s.core.addSqlToTracing(ctx, sqlObj)
if s.core.db.GetDebug() {
s.core.writeSqlToLogger(ctx, sqlObj)
}
return result, err
}
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
result, err := s.doStmtCommit(ctx, stmtTypeExecContext, args...)
if result != nil {
return result.(sql.Result), err
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
Stmt: s.Stmt,
Link: s.link,
Sql: s.sql,
Args: args,
Type: DoCommitTypeStmtExecContext,
})
if out != nil {
return out.Result, err
}
return nil, err
}
@ -97,9 +46,15 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
result, err := s.doStmtCommit(ctx, stmtTypeQueryContext, args...)
if result != nil {
return result.(*sql.Rows), err
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
Stmt: s.Stmt,
Link: s.link,
Sql: s.sql,
Args: args,
Type: DoCommitTypeStmtQueryContext,
})
if out != nil {
return out.Rows, err
}
return nil, err
}
@ -111,9 +66,18 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
result, _ := s.doStmtCommit(ctx, stmtTypeQueryRowContext, args...)
if result != nil {
return result.(*sql.Row)
out, err := s.core.db.DoCommit(ctx, DoCommitInput{
Stmt: s.Stmt,
Link: s.link,
Sql: s.sql,
Args: args,
Type: DoCommitTypeStmtQueryRowContext,
})
if err != nil {
panic(err)
}
if out != nil {
return out.Row
}
return nil
}

View File

@ -41,11 +41,11 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
}, nil
}
// DoCommit handles the sql before posts it to database.
// DoFilter handles the sql before posts it to database.
// It here overwrites the same method of gdb.DriverMysql and makes some custom changes.
func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
func (d *MyDriver) DoFilter(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
latestSqlString.Set(sql)
return d.DriverMysql.DoCommit(ctx, link, sql, args)
return d.DriverMysql.DoFilter(ctx, link, sql, args)
}
func init() {