gdb: add support for slice argument in where statement

This commit is contained in:
John 2018-12-17 10:52:44 +08:00
parent c5e9686a95
commit 7a8bd96edc
6 changed files with 92 additions and 27 deletions

View File

@ -410,7 +410,8 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...)
newWhere, newArgs := formatCondition(condition, params)
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, newWhere), newArgs...)
}
// CURD操作:删除数据
@ -424,7 +425,8 @@ func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{
// CURD操作:删除数据
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...)
newWhere, newArgs := formatCondition(condition, args)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, newWhere), newArgs...)
}
// 获得缓存对象

View File

@ -7,6 +7,7 @@
package gdb
import (
"bytes"
"database/sql"
"errors"
"fmt"
@ -14,9 +15,11 @@ import (
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gregex"
"gitee.com/johng/gf/g/util/gstr"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"reflect"
"strings"
)
// 将数据查询的列表数据*sql.Rows转换为Result类型
@ -55,30 +58,61 @@ func rowsToResult(rows *sql.Rows) (Result, error) {
}
// 格式化SQL查询条件
func formatCondition(condition interface{}) (where string) {
if reflect.ValueOf(condition).Kind() == reflect.Map {
ks := reflect.ValueOf(condition).MapKeys()
vs := reflect.ValueOf(condition)
func formatCondition(where interface{}, args []interface{}) (string, []interface{}) {
// 条件字符串处理
buffer := bytes.NewBuffer(nil)
if reflect.ValueOf(where).Kind() == reflect.Map {
ks := reflect.ValueOf(where).MapKeys()
vs := reflect.ValueOf(where)
for _, k := range ks {
key := gconv.String(k.Interface())
value := gconv.String(vs.MapIndex(k).Interface())
isNum := gstr.IsNumeric(value)
if len(where) > 0 {
where += " AND "
if buffer.Len() > 0 {
buffer.WriteString(" AND ")
}
if isNum || value == "?" {
where += key + "=" + value
if gstr.IsNumeric(value) || value == "?" {
buffer.WriteString(key + "=" + value)
} else {
where += key + "='" + value + "'"
buffer.WriteString(key + "='" + value + "'")
}
}
} else {
where += gconv.String(condition)
buffer.Write(gconv.Bytes(where))
}
if len(where) == 0 {
where = "1"
if buffer.Len() == 0 {
buffer.WriteString("1")
}
return
// 查询条件处理
newWhere := buffer.String()
newArgs := make([]interface{}, 0)
if len(args) > 0 {
for index, arg := range args {
rv := reflect.ValueOf(arg)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Slice: fallthrough
case reflect.Array:
for i := 0; i < rv.Len(); i++ {
newArgs = append(newArgs, rv.Index(i).Interface())
}
counter := 0
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
counter++
if counter == index + 1 {
return "?" + strings.Repeat(",?", rv.Len() - 1)
}
return s
})
default:
newArgs = append(newArgs, arg)
}
}
}
return newWhere, newArgs
}
// 打印SQL对象(仅在debug=true时有效)

View File

@ -108,8 +108,9 @@ func (md *Model) Filter() (*Model) {
// 链式操作condition支持string & gdb.Map
func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
md.where = formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
newWhere, newArgs := formatCondition(where, args)
md.where = newWhere
md.whereArgs = append(md.whereArgs, newArgs...)
// 支持 Where("uid", 1)这种格式
if len(args) == 1 && strings.Index(md.where , "?") < 0 {
md.where += "=?"
@ -119,15 +120,17 @@ func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
// 链式操作添加AND条件到Where中
func (md *Model) And(where interface{}, args ...interface{}) (*Model) {
md.where += " AND " + formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
newWhere, newArgs := formatCondition(where, args)
md.where += " AND " + newWhere
md.whereArgs = append(md.whereArgs, newArgs...)
return md
}
// 链式操作添加OR条件到Where中
func (md *Model) Or(where interface{}, args ...interface{}) (*Model) {
md.where += " OR " + formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
newWhere, newArgs := formatCondition(where, args)
md.where += " OR " + newWhere
md.whereArgs = append(md.whereArgs, newArgs...)
return md
}

View File

@ -168,6 +168,25 @@ func TestModel_GroupBy(t *testing.T) {
gtest.Assert(result[0]["nickname"].String(), "T111")
}
func TestModel_Where1(t *testing.T) {
result, err := db.Table("user").Where("id IN(?)", g.Slice{1,3}).OrderBy("id ASC").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 2)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 3)
}
func TestModel_Where2(t *testing.T) {
result, err := db.Table("user").Where("nickname=? AND id IN(?)", "T3", g.Slice{1,3}).OrderBy("id ASC").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 1)
gtest.Assert(result[0]["id"].Int(), 3)
}
func TestModel_Delete(t *testing.T) {
result, err := db.Table("user").Delete()
if err != nil {

View File

@ -10,7 +10,7 @@ func main() {
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)
r, _ := db.Table("test").Where("id IN (?,?)", 1,2).All()
r, _ := db.Table("test").Where("id IN (?)", []interface{}{1, 2}).All()
if r != nil {
fmt.Println(r.ToList())
}

View File

@ -3,13 +3,20 @@ package main
import (
"fmt"
"gitee.com/johng/gf/g/util/gregex"
"strings"
)
func main() {
query := "select * from user"
q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
fmt.Println(err)
fmt.Println(q)
newWhere := "?????"
counter := 0
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
counter++
if counter == 4 {
return "?" + strings.Repeat(",!", 5 - 1)
}
return s
})
fmt.Println(newWhere)
}