mirror of
https://gitee.com/johng/gf.git
synced 2024-12-02 04:07:47 +08:00
add nested transaction feature for package gdb
This commit is contained in:
parent
5856f74d83
commit
d7eb1cca07
@ -9,6 +9,7 @@ package gdb
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/util/gconv"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
@ -16,21 +17,101 @@ import (
|
||||
|
||||
// TX is the struct for transaction management.
|
||||
type TX struct {
|
||||
db DB
|
||||
tx *sql.Tx
|
||||
master *sql.DB
|
||||
db DB // db is the current gdb database manager.
|
||||
tx *sql.Tx // tx is the raw and underlying transaction manager.
|
||||
master *sql.DB // master is the raw and underlying database manager.
|
||||
transactionCount int // transactionCount marks the times that Begins.
|
||||
}
|
||||
|
||||
// Commit commits the transaction.
|
||||
const (
|
||||
transactionPointerPrefix = "transaction"
|
||||
)
|
||||
|
||||
// Commit commits current transaction.
|
||||
// Note that it releases previous saved transaction point if it's in a nested transaction procedure,
|
||||
// or else it commits the hole transaction.
|
||||
func (tx *TX) Commit() error {
|
||||
if tx.transactionCount > 0 {
|
||||
tx.transactionCount--
|
||||
_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKey())
|
||||
return err
|
||||
}
|
||||
return tx.tx.Commit()
|
||||
}
|
||||
|
||||
// Rollback aborts the transaction.
|
||||
// Rollback aborts current transaction.
|
||||
// Note that it aborts current transaction if it's in a nested transaction procedure,
|
||||
// or else it aborts the hole transaction.
|
||||
func (tx *TX) Rollback() error {
|
||||
if tx.transactionCount > 0 {
|
||||
tx.transactionCount--
|
||||
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKey())
|
||||
return err
|
||||
}
|
||||
return tx.tx.Rollback()
|
||||
}
|
||||
|
||||
// Begin starts a nested transaction procedure.
|
||||
func (tx *TX) Begin() error {
|
||||
_, err := tx.Exec("SAVEPOINT " + tx.transactionKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.transactionCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point.
|
||||
// The parameter `point` specifies the point name that will be saved to server.
|
||||
func (tx *TX) SavePoint(point string) error {
|
||||
_, err := tx.Exec("SAVEPOINT " + tx.db.QuoteWord(point))
|
||||
return err
|
||||
}
|
||||
|
||||
// RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction.
|
||||
// The parameter `point` specifies the point name that was saved previously.
|
||||
func (tx *TX) RollbackTo(point string) error {
|
||||
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.QuoteWord(point))
|
||||
return err
|
||||
}
|
||||
|
||||
// transactionKey forms and returns the transaction key at current save point.
|
||||
func (tx *TX) transactionKey() string {
|
||||
return tx.db.QuoteWord(transactionPointerPrefix + gconv.String(tx.transactionCount))
|
||||
}
|
||||
|
||||
// Transaction wraps the transaction logic using function `f`.
|
||||
// It rollbacks the transaction and returns the error from function `f` if
|
||||
// it returns non-nil error. It commits the transaction and returns nil if
|
||||
// function `f` returns nil.
|
||||
//
|
||||
// Note that, you should not Commit or Rollback the transaction in function `f`
|
||||
// as it is automatically handled by this function.
|
||||
func (tx *TX) Transaction(f func(tx *TX) error) (err error) {
|
||||
err = tx.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("%v", e)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = e
|
||||
}
|
||||
} else {
|
||||
if e := tx.Commit(); e != nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
}()
|
||||
err = f(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Query does query operation on transaction.
|
||||
// See Core.Query.
|
||||
func (tx *TX) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
@ -221,7 +302,7 @@ func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Res
|
||||
return tx.Model(table).Data(list).Insert()
|
||||
}
|
||||
|
||||
// BatchInsert batch inserts data with ignore option.
|
||||
// BatchInsertIgnore batch inserts data with ignore option.
|
||||
// The parameter `list` must be type of slice of map or struct.
|
||||
func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
if len(batch) > 0 {
|
||||
|
@ -789,3 +789,137 @@ func Test_Transaction_Panic(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Transaction_Nested_Begin_Rollback_Commit(t *testing.T) {
|
||||
table := createTable()
|
||||
defer dropTable(table)
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
tx, err := db.Begin()
|
||||
t.AssertNil(err)
|
||||
// tx begin.
|
||||
err = tx.Begin()
|
||||
t.AssertNil(err)
|
||||
// tx rollback.
|
||||
_, err = tx.Model(table).Data(g.Map{
|
||||
"id": 1,
|
||||
"passport": "user_1",
|
||||
"password": "pass_1",
|
||||
"nickname": "name_1",
|
||||
}).Insert()
|
||||
err = tx.Rollback()
|
||||
t.AssertNil(err)
|
||||
// tx commit.
|
||||
_, err = tx.Model(table).Data(g.Map{
|
||||
"id": 2,
|
||||
"passport": "user_2",
|
||||
"password": "pass_2",
|
||||
"nickname": "name_2",
|
||||
}).Insert()
|
||||
err = tx.Commit()
|
||||
t.AssertNil(err)
|
||||
// check data.
|
||||
all, err := db.Model(table).All()
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(all), 1)
|
||||
t.Assert(all[0]["id"], 2)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Transaction_Nested_TX_Transaction(t *testing.T) {
|
||||
table := createTable()
|
||||
defer dropTable(table)
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
var err error
|
||||
err = db.Transaction(func(tx *gdb.TX) error {
|
||||
// commit
|
||||
err = tx.Transaction(func(tx *gdb.TX) error {
|
||||
err = tx.Transaction(func(tx *gdb.TX) error {
|
||||
err = tx.Transaction(func(tx *gdb.TX) error {
|
||||
err = tx.Transaction(func(tx *gdb.TX) error {
|
||||
err = tx.Transaction(func(tx *gdb.TX) error {
|
||||
_, err = tx.Model(table).Data(g.Map{
|
||||
"id": 1,
|
||||
"passport": "USER_1",
|
||||
"password": "PASS_1",
|
||||
"nickname": "NAME_1",
|
||||
"create_time": gtime.Now().String(),
|
||||
}).Insert()
|
||||
t.AssertNil(err)
|
||||
return err
|
||||
})
|
||||
t.AssertNil(err)
|
||||
return err
|
||||
})
|
||||
t.AssertNil(err)
|
||||
return err
|
||||
})
|
||||
t.AssertNil(err)
|
||||
return err
|
||||
})
|
||||
t.AssertNil(err)
|
||||
return err
|
||||
})
|
||||
t.AssertNil(err)
|
||||
// rollback
|
||||
err = tx.Transaction(func(tx *gdb.TX) error {
|
||||
_, err = tx.Model(table).Data(g.Map{
|
||||
"id": 2,
|
||||
"passport": "USER_2",
|
||||
"password": "PASS_2",
|
||||
"nickname": "NAME_2",
|
||||
"create_time": gtime.Now().String(),
|
||||
}).Insert()
|
||||
t.AssertNil(err)
|
||||
panic("error")
|
||||
return err
|
||||
})
|
||||
t.AssertNE(err, nil)
|
||||
return nil
|
||||
})
|
||||
t.AssertNil(err)
|
||||
|
||||
all, err := db.Model(table).All()
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(all), 1)
|
||||
t.Assert(all[0]["id"], 1)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Transaction_Nested_SavePoint_RollbackTo(t *testing.T) {
|
||||
table := createTable()
|
||||
defer dropTable(table)
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
tx, err := db.Begin()
|
||||
t.AssertNil(err)
|
||||
// tx save point.
|
||||
_, err = tx.Model(table).Data(g.Map{
|
||||
"id": 1,
|
||||
"passport": "user_1",
|
||||
"password": "pass_1",
|
||||
"nickname": "name_1",
|
||||
}).Insert()
|
||||
err = tx.SavePoint("MyPoint")
|
||||
t.AssertNil(err)
|
||||
_, err = tx.Model(table).Data(g.Map{
|
||||
"id": 2,
|
||||
"passport": "user_2",
|
||||
"password": "pass_2",
|
||||
"nickname": "name_2",
|
||||
}).Insert()
|
||||
// tx rollback to.
|
||||
err = tx.RollbackTo("MyPoint")
|
||||
t.AssertNil(err)
|
||||
// tx commit.
|
||||
err = tx.Commit()
|
||||
t.AssertNil(err)
|
||||
|
||||
// check data.
|
||||
all, err := db.Model(table).All()
|
||||
t.AssertNil(err)
|
||||
t.Assert(len(all), 1)
|
||||
t.Assert(all[0]["id"], 1)
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user