mirror of
https://gitee.com/johng/gf.git
synced 2024-12-01 19:57:40 +08:00
gdb: add support for slice argument in where statement
This commit is contained in:
parent
c5e9686a95
commit
7a8bd96edc
@ -410,7 +410,8 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...)
|
||||
newWhere, newArgs := formatCondition(condition, params)
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, newWhere), newArgs...)
|
||||
}
|
||||
|
||||
// CURD操作:删除数据
|
||||
@ -424,7 +425,8 @@ func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...)
|
||||
newWhere, newArgs := formatCondition(condition, args)
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, newWhere), newArgs...)
|
||||
}
|
||||
|
||||
// 获得缓存对象
|
||||
|
@ -7,6 +7,7 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -14,9 +15,11 @@ import (
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
"gitee.com/johng/gf/g/util/gstr"
|
||||
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 将数据查询的列表数据*sql.Rows转换为Result类型
|
||||
@ -55,30 +58,61 @@ func rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
}
|
||||
|
||||
// 格式化SQL查询条件
|
||||
func formatCondition(condition interface{}) (where string) {
|
||||
if reflect.ValueOf(condition).Kind() == reflect.Map {
|
||||
ks := reflect.ValueOf(condition).MapKeys()
|
||||
vs := reflect.ValueOf(condition)
|
||||
func formatCondition(where interface{}, args []interface{}) (string, []interface{}) {
|
||||
// 条件字符串处理
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
if reflect.ValueOf(where).Kind() == reflect.Map {
|
||||
ks := reflect.ValueOf(where).MapKeys()
|
||||
vs := reflect.ValueOf(where)
|
||||
for _, k := range ks {
|
||||
key := gconv.String(k.Interface())
|
||||
value := gconv.String(vs.MapIndex(k).Interface())
|
||||
isNum := gstr.IsNumeric(value)
|
||||
if len(where) > 0 {
|
||||
where += " AND "
|
||||
if buffer.Len() > 0 {
|
||||
buffer.WriteString(" AND ")
|
||||
}
|
||||
if isNum || value == "?" {
|
||||
where += key + "=" + value
|
||||
if gstr.IsNumeric(value) || value == "?" {
|
||||
buffer.WriteString(key + "=" + value)
|
||||
} else {
|
||||
where += key + "='" + value + "'"
|
||||
buffer.WriteString(key + "='" + value + "'")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
where += gconv.String(condition)
|
||||
buffer.Write(gconv.Bytes(where))
|
||||
}
|
||||
if len(where) == 0 {
|
||||
where = "1"
|
||||
if buffer.Len() == 0 {
|
||||
buffer.WriteString("1")
|
||||
}
|
||||
return
|
||||
// 查询条件处理
|
||||
newWhere := buffer.String()
|
||||
newArgs := make([]interface{}, 0)
|
||||
if len(args) > 0 {
|
||||
for index, arg := range args {
|
||||
rv := reflect.ValueOf(arg)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Slice: fallthrough
|
||||
case reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
newArgs = append(newArgs, rv.Index(i).Interface())
|
||||
}
|
||||
counter := 0
|
||||
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
|
||||
counter++
|
||||
if counter == index + 1 {
|
||||
return "?" + strings.Repeat(",?", rv.Len() - 1)
|
||||
}
|
||||
return s
|
||||
})
|
||||
default:
|
||||
newArgs = append(newArgs, arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
return newWhere, newArgs
|
||||
}
|
||||
|
||||
// 打印SQL对象(仅在debug=true时有效)
|
||||
|
@ -108,8 +108,9 @@ func (md *Model) Filter() (*Model) {
|
||||
|
||||
// 链式操作,condition,支持string & gdb.Map
|
||||
func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where = formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
md.where = newWhere
|
||||
md.whereArgs = append(md.whereArgs, newArgs...)
|
||||
// 支持 Where("uid", 1)这种格式
|
||||
if len(args) == 1 && strings.Index(md.where , "?") < 0 {
|
||||
md.where += "=?"
|
||||
@ -119,15 +120,17 @@ func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
|
||||
|
||||
// 链式操作,添加AND条件到Where中
|
||||
func (md *Model) And(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where += " AND " + formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
md.where += " AND " + newWhere
|
||||
md.whereArgs = append(md.whereArgs, newArgs...)
|
||||
return md
|
||||
}
|
||||
|
||||
// 链式操作,添加OR条件到Where中
|
||||
func (md *Model) Or(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where += " OR " + formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
md.where += " OR " + newWhere
|
||||
md.whereArgs = append(md.whereArgs, newArgs...)
|
||||
return md
|
||||
}
|
||||
|
||||
|
@ -168,6 +168,25 @@ func TestModel_GroupBy(t *testing.T) {
|
||||
gtest.Assert(result[0]["nickname"].String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Where1(t *testing.T) {
|
||||
result, err := db.Table("user").Where("id IN(?)", g.Slice{1,3}).OrderBy("id ASC").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 2)
|
||||
gtest.Assert(result[0]["id"].Int(), 1)
|
||||
gtest.Assert(result[1]["id"].Int(), 3)
|
||||
}
|
||||
|
||||
func TestModel_Where2(t *testing.T) {
|
||||
result, err := db.Table("user").Where("nickname=? AND id IN(?)", "T3", g.Slice{1,3}).OrderBy("id ASC").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 1)
|
||||
gtest.Assert(result[0]["id"].Int(), 3)
|
||||
}
|
||||
|
||||
func TestModel_Delete(t *testing.T) {
|
||||
result, err := db.Table("user").Delete()
|
||||
if err != nil {
|
||||
|
@ -10,7 +10,7 @@ func main() {
|
||||
// 开启调试模式,以便于记录所有执行的SQL
|
||||
db.SetDebug(true)
|
||||
|
||||
r, _ := db.Table("test").Where("id IN (?,?)", 1,2).All()
|
||||
r, _ := db.Table("test").Where("id IN (?)", []interface{}{1, 2}).All()
|
||||
if r != nil {
|
||||
fmt.Println(r.ToList())
|
||||
}
|
||||
|
@ -3,13 +3,20 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
||||
|
||||
func main() {
|
||||
query := "select * from user"
|
||||
q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
fmt.Println(err)
|
||||
fmt.Println(q)
|
||||
newWhere := "?????"
|
||||
counter := 0
|
||||
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
|
||||
counter++
|
||||
if counter == 4 {
|
||||
return "?" + strings.Repeat(",!", 5 - 1)
|
||||
}
|
||||
return s
|
||||
})
|
||||
fmt.Println(newWhere)
|
||||
}
|
Loading…
Reference in New Issue
Block a user