mirror of
https://gitee.com/johng/gf.git
synced 2024-11-29 18:57:44 +08:00
parent
849b104c31
commit
509cc47d3f
@ -9,21 +9,18 @@ package pgsql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// DoInsert inserts or updates data forF given table.
|
||||
func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) {
|
||||
switch option.InsertOption {
|
||||
case gdb.InsertOptionSave:
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeNotSupported,
|
||||
`Save operation is not supported by pgsql driver`,
|
||||
)
|
||||
|
||||
case gdb.InsertOptionReplace:
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeNotSupported,
|
||||
@ -50,3 +47,55 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list
|
||||
}
|
||||
return d.Core.DoInsert(ctx, link, table, list, option)
|
||||
}
|
||||
|
||||
// FormatUpsert returns SQL clause of type upsert for PgSQL.
|
||||
// For example: ON CONFLICT (id) DO UPDATE SET ...
|
||||
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
|
||||
if len(option.OnConflict) == 0 {
|
||||
return "", gerror.New("Please specify conflict columns")
|
||||
}
|
||||
|
||||
var onDuplicateStr string
|
||||
if option.OnDuplicateStr != "" {
|
||||
onDuplicateStr = option.OnDuplicateStr
|
||||
} else if len(option.OnDuplicateMap) > 0 {
|
||||
for k, v := range option.OnDuplicateMap {
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
switch v.(type) {
|
||||
case gdb.Raw, *gdb.Raw:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=%s",
|
||||
d.Core.QuoteWord(k),
|
||||
v,
|
||||
)
|
||||
default:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=EXCLUDED.%s",
|
||||
d.Core.QuoteWord(k),
|
||||
d.Core.QuoteWord(gconv.String(v)),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, column := range columns {
|
||||
// If it's SAVE operation, do not automatically update the creating time.
|
||||
if d.Core.IsSoftCreatedFieldName(column) {
|
||||
continue
|
||||
}
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=EXCLUDED.%s",
|
||||
d.Core.QuoteWord(column),
|
||||
d.Core.QuoteWord(column),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
conflictKeys := gstr.Join(option.OnConflict, ",")
|
||||
|
||||
return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil
|
||||
}
|
||||
|
@ -78,12 +78,12 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) {
|
||||
|
||||
if _, err := db.Exec(ctx, fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id bigserial NOT NULL,
|
||||
passport varchar(45) NOT NULL,
|
||||
password varchar(32) NOT NULL,
|
||||
nickname varchar(45) NOT NULL,
|
||||
create_time timestamp NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
id bigserial NOT NULL,
|
||||
passport varchar(45) NOT NULL,
|
||||
password varchar(32) NOT NULL,
|
||||
nickname varchar(45) NOT NULL,
|
||||
create_time timestamp NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
) ;`, name,
|
||||
)); err != nil {
|
||||
gtest.Fatal(err)
|
||||
|
@ -7,8 +7,10 @@
|
||||
package pgsql_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/test/gtest"
|
||||
@ -258,14 +260,16 @@ func Test_Model_Save(t *testing.T) {
|
||||
table := createTable()
|
||||
defer dropTable(table)
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
_, err := db.Model(table).Data(g.Map{
|
||||
result, err := db.Model(table).Data(g.Map{
|
||||
"id": 1,
|
||||
"passport": "t111",
|
||||
"password": "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname": "T111",
|
||||
"create_time": "2018-10-24 10:00:00",
|
||||
}).Save()
|
||||
t.Assert(err, "Save operation is not supported by pgsql driver")
|
||||
}).OnConflict("id").Save()
|
||||
t.AssertNil(err)
|
||||
n, _ := result.RowsAffected()
|
||||
t.Assert(n, 1)
|
||||
})
|
||||
}
|
||||
|
||||
@ -284,3 +288,259 @@ func Test_Model_Replace(t *testing.T) {
|
||||
t.Assert(err, "Replace operation is not supported by pgsql driver")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Model_OnConflict(t *testing.T) {
|
||||
var (
|
||||
table = fmt.Sprintf(`%s_%d`, TablePrefix+"test", gtime.TimestampNano())
|
||||
uniqueName = fmt.Sprintf(`%s_%d`, TablePrefix+"test_unique", gtime.TimestampNano())
|
||||
)
|
||||
if _, err := db.Exec(ctx, fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id bigserial NOT NULL,
|
||||
passport varchar(45) NOT NULL,
|
||||
password varchar(32) NOT NULL,
|
||||
nickname varchar(45) NOT NULL,
|
||||
create_time timestamp NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT %s UNIQUE ("passport", "password")
|
||||
) ;`, table, uniqueName,
|
||||
)); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
defer dropTable(table)
|
||||
|
||||
// string type 1.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("passport,password").Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).Where("id", 1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "n1")
|
||||
})
|
||||
|
||||
// string type 2.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("passport", "password").Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).Where("id", 1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "n1")
|
||||
})
|
||||
|
||||
// slice.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict(g.Slice{"passport", "password"}).Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).Where("id", 1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "n1")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Model_OnDuplicate(t *testing.T) {
|
||||
table := createInitTable()
|
||||
defer dropTable(table)
|
||||
|
||||
// string type 1.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicate("passport,password").Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// string type 2.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicate("passport", "password").Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// slice.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Slice{"passport", "password"}).Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// map.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{
|
||||
"passport": "nickname",
|
||||
"password": "nickname",
|
||||
}).Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["nickname"])
|
||||
t.Assert(one["password"], data["nickname"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// map+raw.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.MapStrStr{
|
||||
"id": "1",
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{
|
||||
"passport": gdb.Raw("CONCAT(EXCLUDED.passport, '1')"),
|
||||
"password": gdb.Raw("CONCAT(EXCLUDED.password, '2')"),
|
||||
}).Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"]+"1")
|
||||
t.Assert(one["password"], data["password"]+"2")
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Model_OnDuplicateEx(t *testing.T) {
|
||||
table := createInitTable()
|
||||
defer dropTable(table)
|
||||
|
||||
// string type 1.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicateEx("nickname,create_time").Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// string type 2.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicateEx("nickname", "create_time").Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// slice.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicateEx(g.Slice{"nickname", "create_time"}).Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
|
||||
// map.
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
data := g.Map{
|
||||
"id": 1,
|
||||
"passport": "pp1",
|
||||
"password": "pw1",
|
||||
"nickname": "n1",
|
||||
"create_time": "2016-06-06",
|
||||
}
|
||||
_, err := db.Model(table).OnConflict("id").OnDuplicateEx(g.Map{
|
||||
"nickname": "nickname",
|
||||
"create_time": "nickname",
|
||||
}).Data(data).Save()
|
||||
t.AssertNil(err)
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["passport"], data["passport"])
|
||||
t.Assert(one["password"], data["password"])
|
||||
t.Assert(one["nickname"], "name_1")
|
||||
})
|
||||
}
|
||||
|
@ -175,6 +175,7 @@ type DB interface {
|
||||
ConvertValueForField(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForField
|
||||
ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForLocal
|
||||
CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (LocalType, error) // See Core.CheckLocalTypeForField
|
||||
FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) // See Core.DoFormatUpsert
|
||||
}
|
||||
|
||||
// TX defines the interfaces for ORM transaction operations.
|
||||
@ -320,6 +321,7 @@ type Sql struct {
|
||||
type DoInsertOption struct {
|
||||
OnDuplicateStr string // Custom string for `on duplicated` statement.
|
||||
OnDuplicateMap map[string]interface{} // Custom key-value map from `OnDuplicateEx` function for `on duplicated` statement.
|
||||
OnConflict []string // Custom conflict key of upsert clause, if the database needs it.
|
||||
InsertOption InsertOption // Insert operation in constant value.
|
||||
BatchCount int // Batch count for batch inserting.
|
||||
}
|
||||
|
@ -487,9 +487,12 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List,
|
||||
keysStr = charL + strings.Join(keys, charR+","+charL) + charR
|
||||
operation = GetInsertOperationByOption(option.InsertOption)
|
||||
)
|
||||
// `ON DUPLICATED...` statement only takes effect on Save operation.
|
||||
// Upsert clause only takes effect on Save operation.
|
||||
if option.InsertOption == InsertOptionSave {
|
||||
onDuplicateStr = c.formatOnDuplicate(keys, option)
|
||||
onDuplicateStr, err = c.db.FormatUpsert(keys, list, option)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var (
|
||||
listLength = len(list)
|
||||
@ -537,49 +540,6 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List,
|
||||
return batchResult, nil
|
||||
}
|
||||
|
||||
func (c *Core) formatOnDuplicate(columns []string, option DoInsertOption) string {
|
||||
var onDuplicateStr string
|
||||
if option.OnDuplicateStr != "" {
|
||||
onDuplicateStr = option.OnDuplicateStr
|
||||
} else if len(option.OnDuplicateMap) > 0 {
|
||||
for k, v := range option.OnDuplicateMap {
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
switch v.(type) {
|
||||
case Raw, *Raw:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=%s",
|
||||
c.QuoteWord(k),
|
||||
v,
|
||||
)
|
||||
default:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=VALUES(%s)",
|
||||
c.QuoteWord(k),
|
||||
c.QuoteWord(gconv.String(v)),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, column := range columns {
|
||||
// If it's SAVE operation, do not automatically update the creating time.
|
||||
if c.isSoftCreatedFieldName(column) {
|
||||
continue
|
||||
}
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=VALUES(%s)",
|
||||
c.QuoteWord(column),
|
||||
c.QuoteWord(column),
|
||||
)
|
||||
}
|
||||
}
|
||||
return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr
|
||||
}
|
||||
|
||||
// Update does "UPDATE ... " statement for the table.
|
||||
//
|
||||
// The parameter `data` can be type of string/map/gmap/struct/*struct, etc.
|
||||
@ -798,8 +758,8 @@ func (c *Core) GetTablesWithCache() ([]string, error) {
|
||||
return result.Strings(), nil
|
||||
}
|
||||
|
||||
// isSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time.
|
||||
func (c *Core) isSoftCreatedFieldName(fieldName string) bool {
|
||||
// IsSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time.
|
||||
func (c *Core) IsSoftCreatedFieldName(fieldName string) bool {
|
||||
if fieldName == "" {
|
||||
return false
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ package gdb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
@ -352,6 +353,52 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt
|
||||
return out.Stmt, err
|
||||
}
|
||||
|
||||
// FormatUpsert formats and returns SQL clause part for upsert statement.
|
||||
// In default implements, this function performs upsert statement for MySQL like:
|
||||
// `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...`
|
||||
func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) {
|
||||
var onDuplicateStr string
|
||||
if option.OnDuplicateStr != "" {
|
||||
onDuplicateStr = option.OnDuplicateStr
|
||||
} else if len(option.OnDuplicateMap) > 0 {
|
||||
for k, v := range option.OnDuplicateMap {
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
switch v.(type) {
|
||||
case Raw, *Raw:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=%s",
|
||||
c.QuoteWord(k),
|
||||
v,
|
||||
)
|
||||
default:
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=VALUES(%s)",
|
||||
c.QuoteWord(k),
|
||||
c.QuoteWord(gconv.String(v)),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, column := range columns {
|
||||
// If it's SAVE operation, do not automatically update the creating time.
|
||||
if c.IsSoftCreatedFieldName(column) {
|
||||
continue
|
||||
}
|
||||
if len(onDuplicateStr) > 0 {
|
||||
onDuplicateStr += ","
|
||||
}
|
||||
onDuplicateStr += fmt.Sprintf(
|
||||
"%s=VALUES(%s)",
|
||||
c.QuoteWord(column),
|
||||
c.QuoteWord(column),
|
||||
)
|
||||
}
|
||||
}
|
||||
return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
|
||||
}
|
||||
|
||||
// RowsToResult converts underlying data record type sql.Rows to Result type.
|
||||
func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) {
|
||||
if rows == nil {
|
||||
|
@ -48,8 +48,9 @@ type Model struct {
|
||||
hookHandler HookHandler // Hook functions for model hook feature.
|
||||
unscoped bool // Disables soft deleting features when select/delete operations.
|
||||
safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model.
|
||||
onDuplicate interface{} // onDuplicate is used for ON "DUPLICATE KEY UPDATE" statement.
|
||||
onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns ON "DUPLICATE KEY UPDATE" statement.
|
||||
onDuplicate interface{} // onDuplicate is used for on Upsert clause.
|
||||
onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns on Upsert clause.
|
||||
onConflict interface{} // onConflict is used for conflict keys on Upsert clause.
|
||||
tableAliasMap map[string]string // Table alias to true table name, usually used in join statements.
|
||||
softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model.
|
||||
}
|
||||
|
@ -118,8 +118,24 @@ func (m *Model) Data(data ...interface{}) *Model {
|
||||
return model
|
||||
}
|
||||
|
||||
// OnConflict sets the primary key or index when columns conflicts occurs.
|
||||
// It's not necessary for MySQL driver.
|
||||
func (m *Model) OnConflict(onConflict ...interface{}) *Model {
|
||||
if len(onConflict) == 0 {
|
||||
return m
|
||||
}
|
||||
model := m.getModel()
|
||||
if len(onConflict) > 1 {
|
||||
model.onConflict = onConflict
|
||||
} else if len(onConflict) == 1 {
|
||||
model.onConflict = onConflict[0]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// OnDuplicate sets the operations when columns conflicts occurs.
|
||||
// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
|
||||
// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
|
||||
// The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice.
|
||||
// Example:
|
||||
//
|
||||
@ -148,6 +164,7 @@ func (m *Model) OnDuplicate(onDuplicate ...interface{}) *Model {
|
||||
|
||||
// OnDuplicateEx sets the excluding columns for operations when columns conflict occurs.
|
||||
// In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
|
||||
// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
|
||||
// The parameter `onDuplicateEx` can be type of string/map/slice.
|
||||
// Example:
|
||||
//
|
||||
@ -320,63 +337,71 @@ func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []st
|
||||
InsertOption: insertOption,
|
||||
BatchCount: m.getBatch(),
|
||||
}
|
||||
if insertOption == InsertOptionSave {
|
||||
onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx)
|
||||
if err != nil {
|
||||
return option, err
|
||||
}
|
||||
onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys)
|
||||
if m.onDuplicate != nil {
|
||||
switch m.onDuplicate.(type) {
|
||||
case Raw, *Raw:
|
||||
option.OnDuplicateStr = gconv.String(m.onDuplicate)
|
||||
if insertOption != InsertOptionSave {
|
||||
return
|
||||
}
|
||||
|
||||
onConflictKeys, err := m.formatOnConflictKeys(m.onConflict)
|
||||
if err != nil {
|
||||
return option, err
|
||||
}
|
||||
option.OnConflict = onConflictKeys
|
||||
|
||||
onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx)
|
||||
if err != nil {
|
||||
return option, err
|
||||
}
|
||||
onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys)
|
||||
if m.onDuplicate != nil {
|
||||
switch m.onDuplicate.(type) {
|
||||
case Raw, *Raw:
|
||||
option.OnDuplicateStr = gconv.String(m.onDuplicate)
|
||||
|
||||
default:
|
||||
reflectInfo := reflection.OriginValueAndKind(m.onDuplicate)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.String:
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for k, v := range gconv.Map(m.onDuplicate) {
|
||||
if onDuplicateExKeySet.Contains(k) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[k] = v
|
||||
}
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for _, v := range gconv.Strings(m.onDuplicate) {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
|
||||
default:
|
||||
reflectInfo := reflection.OriginValueAndKind(m.onDuplicate)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.String:
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for k, v := range gconv.Map(m.onDuplicate) {
|
||||
if onDuplicateExKeySet.Contains(k) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[k] = v
|
||||
}
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for _, v := range gconv.Strings(m.onDuplicate) {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
|
||||
default:
|
||||
return option, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`unsupported OnDuplicate parameter type "%s"`,
|
||||
reflect.TypeOf(m.onDuplicate),
|
||||
)
|
||||
}
|
||||
return option, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`unsupported OnDuplicate parameter type "%s"`,
|
||||
reflect.TypeOf(m.onDuplicate),
|
||||
)
|
||||
}
|
||||
} else if onDuplicateExKeySet.Size() > 0 {
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for _, v := range columnNames {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
} else if onDuplicateExKeySet.Size() > 0 {
|
||||
option.OnDuplicateMap = make(map[string]interface{})
|
||||
for _, v := range columnNames {
|
||||
if onDuplicateExKeySet.Contains(v) {
|
||||
continue
|
||||
}
|
||||
option.OnDuplicateMap[v] = v
|
||||
}
|
||||
}
|
||||
return
|
||||
@ -407,6 +432,28 @@ func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, er
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) formatOnConflictKeys(onConflict interface{}) ([]string, error) {
|
||||
if onConflict == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reflectInfo := reflection.OriginValueAndKind(onConflict)
|
||||
switch reflectInfo.OriginKind {
|
||||
case reflect.String:
|
||||
return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
return gconv.Strings(onConflict), nil
|
||||
|
||||
default:
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeInvalidParameter,
|
||||
`unsupported onConflict parameter type "%s"`,
|
||||
reflect.TypeOf(onConflict),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) getBatch() int {
|
||||
return m.batch
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user