gf/g/database/gdb/gdb_base.go

469 lines
16 KiB
Go
Raw Normal View History

2017-12-29 16:03:30 +08:00
// Copyright 2017 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
//
2017-12-31 18:19:58 +08:00
2017-11-23 10:21:28 +08:00
package gdb
import (
"database/sql"
2018-12-14 18:35:51 +08:00
"errors"
"fmt"
"gitee.com/johng/gf/g/os/gcache"
"gitee.com/johng/gf/g/os/gtime"
2018-12-14 18:35:51 +08:00
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gregex"
2018-12-14 18:35:51 +08:00
"reflect"
"strings"
2017-11-23 10:21:28 +08:00
)
const (
gDEFAULT_DEBUG_SQL_LENGTH = 1000 // 默认调试模式下记录的SQL条数
)
// 获取已经执行的SQL列表(仅在debug=true时有效)
func (bs *dbBase) GetQueriedSqls() []*Sql {
if bs.sqls == nil {
return nil
}
sqls := make([]*Sql, 0)
bs.sqls.Prev()
bs.sqls.RLockIteratorPrev(func(value interface{}) bool {
if value == nil {
return false
}
sqls = append(sqls, value.(*Sql))
return true
})
return sqls
}
// 打印已经执行的SQL列表(仅在debug=true时有效)
func (bs *dbBase) PrintQueriedSqls() {
sqls := bs.GetQueriedSqls()
for k, v := range sqls {
fmt.Println(len(sqls) - k, ":")
fmt.Println(" Sql :", v.Sql)
fmt.Println(" Args :", v.Args)
fmt.Println(" Error:", v.Error)
fmt.Println(" Start:", gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"))
fmt.Println(" End :", gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"))
fmt.Println(" Cost :", v.End - v.Start, "ms")
}
}
2017-11-23 10:21:28 +08:00
// 数据库sql查询操作主要执行查询
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := bs.db.Slave()
if err != nil {
return nil,err
}
return bs.db.doQuery(link, query, args...)
}
// 数据库sql查询操作主要执行查询
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
rows, err = link.Query(query, args...)
mTime2 := gtime.Millisecond()
s := &Sql {
Sql : query,
Args : args,
Error : err,
Start : mTime1,
End : mTime2,
}
bs.sqls.Put(s)
printSql(s)
} else {
rows, err = link.Query(query, args ...)
}
if err == nil {
return rows, nil
} else {
err = formatError(err, query, args...)
}
return nil, err
2017-11-23 10:21:28 +08:00
}
// 执行一条sql并返回执行情况主要用于非查询操作
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
if err != nil {
return nil,err
}
return bs.db.doExec(link, query, args...)
}
// 执行一条sql并返回执行情况主要用于非查询操作
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
result, err = link.Exec(query, args ...)
mTime2 := gtime.Millisecond()
s := &Sql{
Sql : query,
Args : args,
Error : err,
Start : mTime1,
End : mTime2,
}
bs.sqls.Put(s)
printSql(s)
} else {
result, err = link.Exec(query, args ...)
}
return result, formatError(err, query, args...)
2017-11-23 10:21:28 +08:00
}
2018-12-14 18:35:51 +08:00
// SQL预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上
func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
2018-12-14 18:35:51 +08:00
err := (error)(nil)
link := (dbLink)(nil)
2018-12-14 18:35:51 +08:00
if len(execOnMaster) > 0 && execOnMaster[0] {
if link, err = bs.db.Master(); err != nil {
2018-12-14 18:35:51 +08:00
return nil, err
}
} else {
if link, err = bs.db.Slave(); err != nil {
2018-12-14 18:35:51 +08:00
return nil, err
2017-11-23 10:21:28 +08:00
}
}
return bs.db.doPrepare(link, query)
}
// SQL预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
return link.Prepare(query)
2017-11-23 10:21:28 +08:00
}
// 数据库查询,获取查询结果集,以列表结构返回
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
rows, err := bs.Query(query, args ...)
2017-11-23 10:21:28 +08:00
if err != nil || rows == nil {
return nil, err
}
defer rows.Close()
2018-12-14 18:35:51 +08:00
return rowsToResult(rows)
2017-11-23 10:21:28 +08:00
}
// 数据库查询,获取查询结果记录,以关联数组结构返回
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
list, err := bs.GetAll(query, args ...)
2017-11-23 10:21:28 +08:00
if err != nil {
return nil, err
}
2018-04-18 16:32:54 +08:00
if len(list) > 0 {
return list[0], nil
}
return nil, nil
2017-11-23 10:21:28 +08:00
}
// 数据库查询获取查询结果记录自动映射数据到给定的struct对象中
func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := bs.GetOne(query, args...)
if err != nil {
return err
}
return one.ToStruct(obj)
}
2017-11-23 10:21:28 +08:00
// 数据库查询,获取查询字段值
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
one, err := bs.GetOne(query, args ...)
2017-11-23 10:21:28 +08:00
if err != nil {
2018-03-12 15:12:38 +08:00
return nil, err
2017-11-23 10:21:28 +08:00
}
2017-12-20 12:05:36 +08:00
for _, v := range one {
2017-11-23 10:21:28 +08:00
return v, nil
}
2018-03-12 15:12:38 +08:00
return nil, nil
2017-11-23 10:21:28 +08:00
}
// 数据库查询,获取查询数量
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
}
value, err := bs.GetValue(query, args ...)
if err != nil {
return 0, err
}
return value.Int(), nil
2018-04-20 23:23:42 +08:00
}
2017-11-23 10:21:28 +08:00
// ping一下判断或保持数据库链接(master)
func (bs *dbBase) PingMaster() error {
if master, err := bs.db.Master(); err != nil {
return err
} else {
return master.Ping()
}
2017-11-23 10:21:28 +08:00
}
// ping一下判断或保持数据库链接(slave)
func (bs *dbBase) PingSlave() error {
if slave, err := bs.db.Slave(); err != nil {
return err
} else {
return slave.Ping()
}
2017-11-23 10:21:28 +08:00
}
// 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略
// 只有在tx.Commit/tx.Rollback时链接会自动Close
func (bs *dbBase) Begin() (*TX, error) {
if master, err := bs.db.Master(); err != nil {
2018-03-12 11:46:12 +08:00
return nil, err
} else {
if tx, err := master.Begin(); err == nil {
2018-12-14 18:35:51 +08:00
return &TX {
db : bs.db,
tx : tx,
master : master,
}, nil
} else {
return nil, err
}
2018-03-12 11:46:12 +08:00
}
2017-11-23 10:21:28 +08:00
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (bs *dbBase) Save(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_SAVE)
}
2017-11-23 10:21:28 +08:00
// insert、replace, save ignore操作
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) {
var fields []string
2017-11-23 10:21:28 +08:00
var values []string
var params []interface{}
charl, charr := bs.db.getChars()
2017-12-20 12:05:36 +08:00
for k, v := range data {
2018-12-14 18:35:51 +08:00
fields = append(fields, charl + k + charr)
2017-11-23 10:21:28 +08:00
values = append(values, "?")
params = append(params, v)
2017-11-23 10:21:28 +08:00
}
2018-12-14 18:35:51 +08:00
operation := getInsertOperationByOption(option)
2017-11-23 10:21:28 +08:00
updatestr := ""
if option == OPTION_SAVE {
var updates []string
2017-12-20 12:05:36 +08:00
for k, _ := range data {
2018-08-13 18:55:28 +08:00
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
2018-12-14 18:35:51 +08:00
charl, k, charr,
charl, k, charr,
2018-08-13 18:55:28 +08:00
),
)
2017-11-23 10:21:28 +08:00
}
2018-08-13 18:55:28 +08:00
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
2017-11-23 10:21:28 +08:00
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updatestr),
params...)
2017-11-23 10:21:28 +08:00
}
// CURD操作:批量数据指定批次量写入
func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT)
2017-11-23 10:21:28 +08:00
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE)
2017-11-23 10:21:28 +08:00
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE)
2017-11-23 10:21:28 +08:00
}
// 批量写入数据
func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) {
2017-11-23 10:21:28 +08:00
var keys []string
var values []string
var bvalues []string
var params []interface{}
// 判断长度
if len(list) < 1 {
2018-01-08 14:15:46 +08:00
return result, errors.New("empty data list")
2017-11-23 10:21:28 +08:00
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
return
}
}
2017-11-23 10:21:28 +08:00
// 首先获取字段名称及记录长度
2017-12-20 12:05:36 +08:00
for k, _ := range list[0] {
2017-11-23 10:21:28 +08:00
keys = append(keys, k)
values = append(values, "?")
}
charl, charr := bs.db.getChars()
2018-12-14 18:35:51 +08:00
keyStr := charl + strings.Join(keys, charl + "," + charr) + charr
2018-06-30 23:12:37 +08:00
valueHolderStr := "(" + strings.Join(values, ",") + ")"
2017-11-23 10:21:28 +08:00
// 操作判断
2018-12-14 18:35:51 +08:00
operation := getInsertOperationByOption(option)
2017-11-23 10:21:28 +08:00
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for _, k := range keys {
2018-08-13 18:55:28 +08:00
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
2018-12-14 18:35:51 +08:00
charl, k, charr,
charl, k, charr,
2018-08-13 18:55:28 +08:00
),
)
2017-11-23 10:21:28 +08:00
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
// 构造批量写入数据格式(注意map的遍历是无序的)
for i := 0; i < len(list); i++ {
2017-11-23 10:21:28 +08:00
for _, k := range keys {
params = append(params, list[i][k])
2017-11-23 10:21:28 +08:00
}
2018-06-30 23:12:37 +08:00
bvalues = append(bvalues, valueHolderStr)
2017-11-23 10:21:28 +08:00
if len(bvalues) == batch {
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
2018-10-29 21:05:15 +08:00
operation, table, keyStr, strings.Join(bvalues, ","),
2018-06-30 23:12:37 +08:00
updatestr),
params...)
2017-11-23 10:21:28 +08:00
if err != nil {
2018-01-08 14:15:46 +08:00
return result, err
2017-11-23 10:21:28 +08:00
}
2018-01-08 14:15:46 +08:00
result = r
2018-06-30 23:12:37 +08:00
params = params[:0]
2017-11-23 10:21:28 +08:00
bvalues = bvalues[:0]
}
}
// 处理最后不构成指定批量的数据
2018-01-08 14:15:46 +08:00
if len(bvalues) > 0 {
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
2018-10-29 21:05:15 +08:00
operation, table, keyStr, strings.Join(bvalues, ","),
2018-06-30 23:12:37 +08:00
updatestr),
params...)
2017-11-23 10:21:28 +08:00
if err != nil {
2018-01-08 14:15:46 +08:00
return result, err
2017-11-23 10:21:28 +08:00
}
2018-01-08 14:15:46 +08:00
result = r
2017-11-23 10:21:28 +08:00
}
2018-01-08 14:15:46 +08:00
return result, nil
2017-11-23 10:21:28 +08:00
}
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
link, err := bs.db.Master()
if err != nil {
return nil, err
}
return bs.db.doUpdate(link, table, data, condition, args ...)
2017-11-23 10:21:28 +08:00
}
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) {
params := ([]interface{})(nil)
updates := ""
charl, charr := bs.db.getChars()
2018-12-14 18:35:51 +08:00
refValue := reflect.ValueOf(data)
if refValue.Kind() == reflect.Map {
var fields []string
keys := refValue.MapKeys()
for _, k := range keys {
2018-12-14 18:35:51 +08:00
fields = append(fields, fmt.Sprintf("%s%s%s=?", charl, k, charr))
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
}
updates = strings.Join(fields, ",")
} else {
updates = gconv.String(data)
2017-11-23 10:21:28 +08:00
}
for _, v := range args {
params = append(params, gconv.String(v))
2017-11-23 10:21:28 +08:00
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
}
newWhere, newArgs := formatCondition(condition, params)
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, newWhere), newArgs...)
2017-11-23 10:21:28 +08:00
}
// CURD操作:删除数据
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
if err != nil {
return nil, err
}
return bs.db.doDelete(link, table, condition, args ...)
}
// CURD操作:删除数据
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
newWhere, newArgs := formatCondition(condition, args)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, newWhere), newArgs...)
2018-12-14 18:35:51 +08:00
}
// 获得缓存对象
func (bs *dbBase) getCache() *gcache.Cache {
return bs.cache
2018-12-14 18:35:51 +08:00
}
2018-12-16 22:22:07 +08:00
// 将map的数据按照fields进行过滤只保留与表字段同名的数据
func (bs *dbBase) filterFields(table string, data map[string]interface{}) map[string]interface{} {
if fields, err := bs.db.getTableFields(table); err == nil {
for k, _ := range data {
if _, ok := fields[k]; !ok {
delete(data, k)
}
}
}
return data
}
2018-12-16 22:27:04 +08:00
// 获得指定表表的数据结构构造成map哈希表返回其中键名为表字段名称键值暂无用途(默认为字段数据类型).
2018-12-16 22:22:07 +08:00
func (bs *dbBase) getTableFields(table string) (fields map[string]string, err error) {
2018-12-16 22:27:04 +08:00
// 缓存不存在时会查询数据表结构,缓存后不过期,直至程序重启(重新部署)
2018-12-16 22:22:07 +08:00
v := bs.cache.GetOrSetFunc("table_fields_" + table, func() interface{} {
result := (Result)(nil)
charl, charr := bs.db.getChars()
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charl, table, charr))
if err != nil {
return nil
}
fields = make(map[string]string)
for _, m := range result {
fields[m["Field"].String()] = m["Type"].String()
}
return fields
}, 0)
if err == nil {
fields = v.(map[string]string)
}
return
}