Merge pull request #1028 from arieslee/gdb-counter

add update counter method for package gdb.
This commit is contained in:
John Guo 2020-11-29 21:46:17 +08:00 committed by GitHub
commit 600c081801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 2 deletions

View File

@ -214,6 +214,12 @@ type Link interface {
Prepare(sql string) (*sql.Stmt, error) Prepare(sql string) (*sql.Stmt, error)
} }
// Counter is the type for update count.
type Counter struct {
Field string
Value float64
}
type ( type (
// Value is the field value type. // Value is the field value type.
Value = *gvar.Var Value = *gvar.Var

View File

@ -699,8 +699,35 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str
dataMap = ConvertDataForTableRecord(data) dataMap = ConvertDataForTableRecord(data)
) )
for k, v := range dataMap { for k, v := range dataMap {
fields = append(fields, c.DB.QuoteWord(k)+"=?") switch value := v.(type) {
params = append(params, v) case *Counter:
if value.Value != 0 {
column := c.DB.QuoteWord(value.Field)
var symbol string
if value.Value < 0 {
symbol = "-"
} else {
symbol = "+"
}
fields = append(fields, fmt.Sprintf("%s=%s%s?", column, column, symbol))
params = append(params, value.Value)
}
case Counter:
if value.Value != 0 {
column := c.DB.QuoteWord(value.Field)
var symbol string
if value.Value < 0 {
symbol = "-"
} else {
symbol = "+"
}
fields = append(fields, fmt.Sprintf("%s=%s%s?", column, column, symbol))
params = append(params, value.Value)
}
default:
fields = append(fields, c.DB.QuoteWord(k)+"=?")
params = append(params, v)
}
} }
updates = strings.Join(fields, ",") updates = strings.Join(fields, ",")
default: default:

View File

@ -127,6 +127,8 @@ func ConvertDataForTableRecord(value interface{}) map[string]interface{} {
switch v.(type) { switch v.(type) {
case time.Time, *time.Time, gtime.Time, *gtime.Time: case time.Time, *time.Time, gtime.Time, *gtime.Time:
continue continue
case Counter, *Counter:
continue
default: default:
// Use string conversion in default. // Use string conversion in default.
if s, ok := v.(apiString); ok { if s, ok := v.(apiString); ok {

View File

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"github.com/gogf/gf/container/garray" "github.com/gogf/gf/container/garray"
"github.com/gogf/gf/encoding/gparser" "github.com/gogf/gf/encoding/gparser"
"github.com/gogf/gf/util/gconv"
"testing" "testing"
"time" "time"
@ -1405,3 +1406,47 @@ func Test_Empty_Slice_Argument(t *testing.T) {
t.Assert(len(result), 0) t.Assert(len(result), 0)
}) })
} }
// update counter test
func Test_DB_UpdateCounter(t *testing.T) {
tableName := "gf_update_counter_test"
defer dropTable(tableName)
_, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id int(10) unsigned NOT NULL,
views int(8) unsigned DEFAULT '0' NOT NULL ,
updated_time int(10) unsigned DEFAULT '0' NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableName))
if err != nil {
gtest.Fatal(err)
}
id := 1
insertData := g.Map{
"id": id,
"views": 0,
"updated_time": 0,
}
_, err = db.Insert(tableName, insertData)
if err != nil {
gtest.Fatal(err)
}
gtest.C(t, func(t *gtest.T) {
gdbCounter := &gdb.Counter{
Field: "views",
Value: 1,
}
updateData := g.Map{
"views": gdbCounter,
"updated_time": gtime.Now().Unix(),
}
result, err := db.Update(tableName, updateData, "id="+gconv.String(id))
t.Assert(err, nil)
n, _ := result.RowsAffected()
t.Assert(n, 1)
one, err := db.Table(tableName).Where("id", id).One()
t.Assert(err, nil)
t.Assert(one["id"].Int(), 1)
t.Assert(one["views"].String(), "1")
})
}