diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 67a1489af..a7490a351 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -214,6 +214,12 @@ type Link interface { Prepare(sql string) (*sql.Stmt, error) } +// Counter is the type for update count. +type Counter struct { + Field string + Value float64 +} + type ( // Value is the field value type. Value = *gvar.Var diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 58058204d..34aa29d8b 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -699,8 +699,35 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str dataMap = ConvertDataForTableRecord(data) ) for k, v := range dataMap { - fields = append(fields, c.DB.QuoteWord(k)+"=?") - params = append(params, v) + switch value := v.(type) { + case *Counter: + if value.Value != 0 { + column := c.DB.QuoteWord(value.Field) + var symbol string + if value.Value < 0 { + symbol = "-" + } else { + symbol = "+" + } + fields = append(fields, fmt.Sprintf("%s=%s%s?", column, column, symbol)) + params = append(params, value.Value) + } + case Counter: + if value.Value != 0 { + column := c.DB.QuoteWord(value.Field) + var symbol string + if value.Value < 0 { + symbol = "-" + } else { + symbol = "+" + } + fields = append(fields, fmt.Sprintf("%s=%s%s?", column, column, symbol)) + params = append(params, value.Value) + } + default: + fields = append(fields, c.DB.QuoteWord(k)+"=?") + params = append(params, v) + } } updates = strings.Join(fields, ",") default: diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 03de4496c..db0942071 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -127,6 +127,8 @@ func ConvertDataForTableRecord(value interface{}) map[string]interface{} { switch v.(type) { case time.Time, *time.Time, gtime.Time, *gtime.Time: continue + case Counter, *Counter: + continue default: // Use string conversion in default. if s, ok := v.(apiString); ok { diff --git a/database/gdb/gdb_z_mysql_method_test.go b/database/gdb/gdb_z_mysql_method_test.go index a0b5f819c..1929713c5 100644 --- a/database/gdb/gdb_z_mysql_method_test.go +++ b/database/gdb/gdb_z_mysql_method_test.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/encoding/gparser" + "github.com/gogf/gf/util/gconv" "testing" "time" @@ -1405,3 +1406,47 @@ func Test_Empty_Slice_Argument(t *testing.T) { t.Assert(len(result), 0) }) } + +// update counter test +func Test_DB_UpdateCounter(t *testing.T) { + tableName := "gf_update_counter_test" + defer dropTable(tableName) + _, err := db.Exec(fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id int(10) unsigned NOT NULL, + views int(8) unsigned DEFAULT '0' NOT NULL , + updated_time int(10) unsigned DEFAULT '0' NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableName)) + if err != nil { + gtest.Fatal(err) + } + id := 1 + insertData := g.Map{ + "id": id, + "views": 0, + "updated_time": 0, + } + _, err = db.Insert(tableName, insertData) + if err != nil { + gtest.Fatal(err) + } + gtest.C(t, func(t *gtest.T) { + gdbCounter := &gdb.Counter{ + Field: "views", + Value: 1, + } + updateData := g.Map{ + "views": gdbCounter, + "updated_time": gtime.Now().Unix(), + } + result, err := db.Update(tableName, updateData, "id="+gconv.String(id)) + t.Assert(err, nil) + n, _ := result.RowsAffected() + t.Assert(n, 1) + one, err := db.Table(tableName).Where("id", id).One() + t.Assert(err, nil) + t.Assert(one["id"].Int(), 1) + t.Assert(one["views"].String(), "1") + }) +}