improve argument handling for empty slice for package gdb

This commit is contained in:
john 2020-07-11 09:53:16 +08:00
parent 6712a33164
commit 4e027c1de3
7 changed files with 72 additions and 27 deletions

View File

@ -2,7 +2,6 @@ package main
import (
"fmt"
"github.com/gogf/gf/frame/g"
)
@ -11,11 +10,19 @@ func main() {
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)
r, e := db.Table("test").Order("id asc").All()
r, e := db.GetAll("SELECT * from `user` where id in(?)", g.Slice{})
if e != nil {
fmt.Println(e)
}
if r != nil {
fmt.Println(r.List())
fmt.Println(r)
}
return
//r, e := db.Table("user").Where("id in(?)", g.Slice{}).All()
//if e != nil {
// fmt.Println(e)
//}
//if r != nil {
// fmt.Println(r.List())
//}
}

View File

@ -29,7 +29,7 @@ type DB interface {
// Model creation.
// ===========================================================================
// Deprecated, use Model instead. The DB interface is designed not only for
// The DB interface is designed not only for
// relational databases but also for NoSQL databases in the future. The name
// "Table" is not proper for that purpose any more.
Table(table ...string) *Model

View File

@ -21,10 +21,6 @@ import (
"github.com/gogf/gf/util/gconv"
)
const (
gPATH_FILTER_KEY = "/database/gdb/gdb"
)
// Master creates and returns a connection from master node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Master() (*sql.DB, error) {
@ -777,8 +773,8 @@ func (c *Core) writeSqlToLogger(v *Sql) {
s := fmt.Sprintf("[%3d ms] %s", v.End-v.Start, v.Format)
if v.Error != nil {
s += "\nError: " + v.Error.Error()
c.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s)
c.logger.Error(s)
} else {
c.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s)
c.logger.Debug(s)
}
}

View File

@ -254,7 +254,7 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
// formatSql formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func formatSql(sql string, args []interface{}) (newQuery string, newArgs []interface{}) {
func formatSql(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
sql = gstr.Trim(sql)
sql = gstr.Replace(sql, "\n", " ")
sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
@ -344,7 +344,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
}
// formatWhereInterfaces formats <where> as []interface{}.
// TODO []interface{} type support for parameter <where> does not completed yet.
// TODO supporting for parameter <where> with []interface{} type is not completed yet.
func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, newArgs []interface{}) []interface{} {
var str string
var array []interface{}
@ -457,9 +457,26 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i
newArgs = append(newArgs, arg)
continue
}
for i := 0; i < rv.Len(); i++ {
newArgs = append(newArgs, rv.Index(i).Interface())
if rv.Len() == 0 {
// Empty slice argument, it converts the sql to a false sql.
// Eg:
// Query("select * from xxx where id in(?)", g.Slice{}) -> select * from xxx where 0=1
// Where("id in(?)", g.Slice{}) -> WHERE 0=1
if gstr.Contains(newSql, "?") {
whereKeyWord := " WHERE "
if p := gstr.PosI(newSql, whereKeyWord); p == -1 {
return "0=1", []interface{}{}
} else {
return gstr.SubStr(newSql, 0, p+len(whereKeyWord)) + "0=1", []interface{}{}
}
}
} else {
for i := 0; i < rv.Len(); i++ {
newArgs = append(newArgs, rv.Index(i).Interface())
}
}
// If the '?' holder count equals the length of the slice,
// it does not implement the arguments splitting logic.
// Eg: db.Query("SELECT ?+?", g.Slice{1, 2})

View File

@ -1186,26 +1186,26 @@ func Test_Model_InnerJoin(t *testing.T) {
res, err := db.Table(table1).Where("id > ?", 5).Delete()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
n, err := res.RowsAffected()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(n, 5)
result, err := db.Table(table1+" u1").InnerJoin(table2+" u2", "u1.id = u2.id").OrderBy("u1.id").Select()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(len(result), 5)
result, err = db.Table(table1+" u1").InnerJoin(table2+" u2", "u1.id = u2.id").Where("u1.id > ?", 1).OrderBy("u1.id").Select()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(len(result), 4)
@ -1222,26 +1222,26 @@ func Test_Model_LeftJoin(t *testing.T) {
res, err := db.Table(table2).Where("id > ?", 3).Delete()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
n, err := res.RowsAffected()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
} else {
t.Assert(n, 7)
}
result, err := db.Table(table1+" u1").LeftJoin(table2+" u2", "u1.id = u2.id").Select()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(len(result), 10)
result, err = db.Table(table1+" u1").LeftJoin(table2+" u2", "u1.id = u2.id").Where("u1.id > ? ", 2).Select()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(len(result), 8)
@ -1258,26 +1258,36 @@ func Test_Model_RightJoin(t *testing.T) {
res, err := db.Table(table1).Where("id > ?", 3).Delete()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
n, err := res.RowsAffected()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(n, 7)
result, err := db.Table(table1+" u1").RightJoin(table2+" u2", "u1.id = u2.id").Select()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(len(result), 10)
result, err = db.Table(table1+" u1").RightJoin(table2+" u2", "u1.id = u2.id").Where("u1.id > 2").Select()
if err != nil {
gtest.Fatal(err)
t.Fatal(err)
}
t.Assert(len(result), 1)
})
}
func Test_Empty_Slice_Argument(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
result, err := db.GetAll(fmt.Sprintf(`select * from %s where id in(?)`, table), g.Slice{})
t.Assert(err, nil)
t.Assert(len(result), 0)
})
}

View File

@ -2422,3 +2422,18 @@ func Test_Model_NullField(t *testing.T) {
t.Assert(user.Passport, data["passport"])
})
}
func Test_Model_Empty_Slice_Argument(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
result, err := db.Model(table).Where(`id`, g.Slice{}).All()
t.Assert(err, nil)
t.Assert(len(result), 0)
})
gtest.C(t, func(t *gtest.T) {
result, err := db.Model(table).Where(`id in(?)`, g.Slice{}).All()
t.Assert(err, nil)
t.Assert(len(result), 0)
})
}

View File

@ -91,7 +91,7 @@ func DB(name ...string) gdb.DB {
return gins.Database(name...)
}
// Deprecated, use Model instead.
// Table is alias of Model.
func Table(tables string, db ...string) *gdb.Model {
return DB(db...).Table(tables)
}