From 509cc47d3f955f4a0a60655a2e5e5015ecfc159b Mon Sep 17 00:00:00 2001 From: oldme <45782393+oldme-git@users.noreply.github.com> Date: Mon, 4 Mar 2024 20:17:43 +0800 Subject: [PATCH] enhance: add `Save` operation support for pgsql #3053 (#3324) --- contrib/drivers/pgsql/pgsql_do_insert.go | 61 +++- .../drivers/pgsql/pgsql_z_unit_init_test.go | 12 +- .../drivers/pgsql/pgsql_z_unit_model_test.go | 266 +++++++++++++++++- database/gdb/gdb.go | 2 + database/gdb/gdb_core.go | 54 +--- database/gdb/gdb_core_underlying.go | 47 ++++ database/gdb/gdb_model.go | 5 +- database/gdb/gdb_model_insert.go | 153 ++++++---- 8 files changed, 483 insertions(+), 117 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index b7586e3d7..84995ff37 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -9,21 +9,18 @@ package pgsql import ( "context" "database/sql" + "fmt" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" ) // DoInsert inserts or updates data forF given table. func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { switch option.InsertOption { - case gdb.InsertOptionSave: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Save operation is not supported by pgsql driver`, - ) - case gdb.InsertOptionReplace: return nil, gerror.NewCode( gcode.CodeNotSupported, @@ -50,3 +47,55 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list } return d.Core.DoInsert(ctx, link, table, list, option) } + +// FormatUpsert returns SQL clause of type upsert for PgSQL. +// For example: ON CONFLICT (id) DO UPDATE SET ... +func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { + if len(option.OnConflict) == 0 { + return "", gerror.New("Please specify conflict columns") + } + + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case gdb.Raw, *gdb.Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + d.Core.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(k), + d.Core.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if d.Core.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(column), + d.Core.QuoteWord(column), + ) + } + } + + conflictKeys := gstr.Join(option.OnConflict, ",") + + return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil +} diff --git a/contrib/drivers/pgsql/pgsql_z_unit_init_test.go b/contrib/drivers/pgsql/pgsql_z_unit_init_test.go index 8e67af502..c2033e301 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_init_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_init_test.go @@ -78,12 +78,12 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { if _, err := db.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( - id bigserial NOT NULL, - passport varchar(45) NOT NULL, - password varchar(32) NOT NULL, - nickname varchar(45) NOT NULL, - create_time timestamp NOT NULL, - PRIMARY KEY (id) + id bigserial NOT NULL, + passport varchar(45) NOT NULL, + password varchar(32) NOT NULL, + nickname varchar(45) NOT NULL, + create_time timestamp NOT NULL, + PRIMARY KEY (id) ) ;`, name, )); err != nil { gtest.Fatal(err) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 26ce0797a..81411ec49 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -7,8 +7,10 @@ package pgsql_test import ( + "fmt" "testing" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/test/gtest" @@ -258,14 +260,16 @@ func Test_Model_Save(t *testing.T) { table := createTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + result, err := db.Model(table).Data(g.Map{ "id": 1, "passport": "t111", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T111", "create_time": "2018-10-24 10:00:00", - }).Save() - t.Assert(err, "Save operation is not supported by pgsql driver") + }).OnConflict("id").Save() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) }) } @@ -284,3 +288,259 @@ func Test_Model_Replace(t *testing.T) { t.Assert(err, "Replace operation is not supported by pgsql driver") }) } + +func Test_Model_OnConflict(t *testing.T) { + var ( + table = fmt.Sprintf(`%s_%d`, TablePrefix+"test", gtime.TimestampNano()) + uniqueName = fmt.Sprintf(`%s_%d`, TablePrefix+"test_unique", gtime.TimestampNano()) + ) + if _, err := db.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id bigserial NOT NULL, + passport varchar(45) NOT NULL, + password varchar(32) NOT NULL, + nickname varchar(45) NOT NULL, + create_time timestamp NOT NULL, + PRIMARY KEY (id), + CONSTRAINT %s UNIQUE ("passport", "password") + ) ;`, table, uniqueName, + )); err != nil { + gtest.Fatal(err) + } + defer dropTable(table) + + // string type 1. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("passport,password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).Where("id", 1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "n1") + }) + + // string type 2. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("passport", "password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).Where("id", 1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "n1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict(g.Slice{"passport", "password"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).Where("id", 1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "n1") + }) +} + +func Test_Model_OnDuplicate(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // string type 1. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate("passport,password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // string type 2. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate("passport", "password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Slice{"passport", "password"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // map. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "passport": "nickname", + "password": "nickname", + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["nickname"]) + t.Assert(one["password"], data["nickname"]) + t.Assert(one["nickname"], "name_1") + }) + + // map+raw. + gtest.C(t, func(t *gtest.T) { + data := g.MapStrStr{ + "id": "1", + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "passport": gdb.Raw("CONCAT(EXCLUDED.passport, '1')"), + "password": gdb.Raw("CONCAT(EXCLUDED.password, '2')"), + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]+"1") + t.Assert(one["password"], data["password"]+"2") + t.Assert(one["nickname"], "name_1") + }) +} + +func Test_Model_OnDuplicateEx(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // string type 1. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx("nickname,create_time").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // string type 2. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx("nickname", "create_time").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx(g.Slice{"nickname", "create_time"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // map. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx(g.Map{ + "nickname": "nickname", + "create_time": "nickname", + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) +} diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 89f02b799..d63c12ffe 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -175,6 +175,7 @@ type DB interface { ConvertValueForField(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForField ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForLocal CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (LocalType, error) // See Core.CheckLocalTypeForField + FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) // See Core.DoFormatUpsert } // TX defines the interfaces for ORM transaction operations. @@ -320,6 +321,7 @@ type Sql struct { type DoInsertOption struct { OnDuplicateStr string // Custom string for `on duplicated` statement. OnDuplicateMap map[string]interface{} // Custom key-value map from `OnDuplicateEx` function for `on duplicated` statement. + OnConflict []string // Custom conflict key of upsert clause, if the database needs it. InsertOption InsertOption // Insert operation in constant value. BatchCount int // Batch count for batch inserting. } diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index c0292d042..c48f45dd5 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -487,9 +487,12 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, keysStr = charL + strings.Join(keys, charR+","+charL) + charR operation = GetInsertOperationByOption(option.InsertOption) ) - // `ON DUPLICATED...` statement only takes effect on Save operation. + // Upsert clause only takes effect on Save operation. if option.InsertOption == InsertOptionSave { - onDuplicateStr = c.formatOnDuplicate(keys, option) + onDuplicateStr, err = c.db.FormatUpsert(keys, list, option) + if err != nil { + return nil, err + } } var ( listLength = len(list) @@ -537,49 +540,6 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, return batchResult, nil } -func (c *Core) formatOnDuplicate(columns []string, option DoInsertOption) string { - var onDuplicateStr string - if option.OnDuplicateStr != "" { - onDuplicateStr = option.OnDuplicateStr - } else if len(option.OnDuplicateMap) > 0 { - for k, v := range option.OnDuplicateMap { - if len(onDuplicateStr) > 0 { - onDuplicateStr += "," - } - switch v.(type) { - case Raw, *Raw: - onDuplicateStr += fmt.Sprintf( - "%s=%s", - c.QuoteWord(k), - v, - ) - default: - onDuplicateStr += fmt.Sprintf( - "%s=VALUES(%s)", - c.QuoteWord(k), - c.QuoteWord(gconv.String(v)), - ) - } - } - } else { - for _, column := range columns { - // If it's SAVE operation, do not automatically update the creating time. - if c.isSoftCreatedFieldName(column) { - continue - } - if len(onDuplicateStr) > 0 { - onDuplicateStr += "," - } - onDuplicateStr += fmt.Sprintf( - "%s=VALUES(%s)", - c.QuoteWord(column), - c.QuoteWord(column), - ) - } - } - return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr -} - // Update does "UPDATE ... " statement for the table. // // The parameter `data` can be type of string/map/gmap/struct/*struct, etc. @@ -798,8 +758,8 @@ func (c *Core) GetTablesWithCache() ([]string, error) { return result.Strings(), nil } -// isSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time. -func (c *Core) isSoftCreatedFieldName(fieldName string) bool { +// IsSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time. +func (c *Core) IsSoftCreatedFieldName(fieldName string) bool { if fieldName == "" { return false } diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 85b0aaa84..d3b5b5b88 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -10,6 +10,7 @@ package gdb import ( "context" "database/sql" + "fmt" "reflect" "go.opentelemetry.io/otel" @@ -352,6 +353,52 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt return out.Stmt, err } +// FormatUpsert formats and returns SQL clause part for upsert statement. +// In default implements, this function performs upsert statement for MySQL like: +// `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...` +func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) { + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case Raw, *Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + c.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(k), + c.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if c.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(column), + c.QuoteWord(column), + ) + } + } + return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil +} + // RowsToResult converts underlying data record type sql.Rows to Result type. func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) { if rows == nil { diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index f85406364..ab9cf4317 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -48,8 +48,9 @@ type Model struct { hookHandler HookHandler // Hook functions for model hook feature. unscoped bool // Disables soft deleting features when select/delete operations. safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. - onDuplicate interface{} // onDuplicate is used for ON "DUPLICATE KEY UPDATE" statement. - onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns ON "DUPLICATE KEY UPDATE" statement. + onDuplicate interface{} // onDuplicate is used for on Upsert clause. + onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns on Upsert clause. + onConflict interface{} // onConflict is used for conflict keys on Upsert clause. tableAliasMap map[string]string // Table alias to true table name, usually used in join statements. softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model. } diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index a470d3d1c..3f99f23fb 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -118,8 +118,24 @@ func (m *Model) Data(data ...interface{}) *Model { return model } +// OnConflict sets the primary key or index when columns conflicts occurs. +// It's not necessary for MySQL driver. +func (m *Model) OnConflict(onConflict ...interface{}) *Model { + if len(onConflict) == 0 { + return m + } + model := m.getModel() + if len(onConflict) > 1 { + model.onConflict = onConflict + } else if len(onConflict) == 1 { + model.onConflict = onConflict[0] + } + return model +} + // OnDuplicate sets the operations when columns conflicts occurs. // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement. // The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice. // Example: // @@ -148,6 +164,7 @@ func (m *Model) OnDuplicate(onDuplicate ...interface{}) *Model { // OnDuplicateEx sets the excluding columns for operations when columns conflict occurs. // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement. // The parameter `onDuplicateEx` can be type of string/map/slice. // Example: // @@ -320,63 +337,71 @@ func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []st InsertOption: insertOption, BatchCount: m.getBatch(), } - if insertOption == InsertOptionSave { - onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx) - if err != nil { - return option, err - } - onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys) - if m.onDuplicate != nil { - switch m.onDuplicate.(type) { - case Raw, *Raw: - option.OnDuplicateStr = gconv.String(m.onDuplicate) + if insertOption != InsertOptionSave { + return + } + + onConflictKeys, err := m.formatOnConflictKeys(m.onConflict) + if err != nil { + return option, err + } + option.OnConflict = onConflictKeys + + onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx) + if err != nil { + return option, err + } + onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys) + if m.onDuplicate != nil { + switch m.onDuplicate.(type) { + case Raw, *Raw: + option.OnDuplicateStr = gconv.String(m.onDuplicate) + + default: + reflectInfo := reflection.OriginValueAndKind(m.onDuplicate) + switch reflectInfo.OriginKind { + case reflect.String: + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v + } + + case reflect.Map: + option.OnDuplicateMap = make(map[string]interface{}) + for k, v := range gconv.Map(m.onDuplicate) { + if onDuplicateExKeySet.Contains(k) { + continue + } + option.OnDuplicateMap[k] = v + } + + case reflect.Slice, reflect.Array: + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range gconv.Strings(m.onDuplicate) { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v + } default: - reflectInfo := reflection.OriginValueAndKind(m.onDuplicate) - switch reflectInfo.OriginKind { - case reflect.String: - option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") { - if onDuplicateExKeySet.Contains(v) { - continue - } - option.OnDuplicateMap[v] = v - } - - case reflect.Map: - option.OnDuplicateMap = make(map[string]interface{}) - for k, v := range gconv.Map(m.onDuplicate) { - if onDuplicateExKeySet.Contains(k) { - continue - } - option.OnDuplicateMap[k] = v - } - - case reflect.Slice, reflect.Array: - option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range gconv.Strings(m.onDuplicate) { - if onDuplicateExKeySet.Contains(v) { - continue - } - option.OnDuplicateMap[v] = v - } - - default: - return option, gerror.NewCodef( - gcode.CodeInvalidParameter, - `unsupported OnDuplicate parameter type "%s"`, - reflect.TypeOf(m.onDuplicate), - ) - } + return option, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported OnDuplicate parameter type "%s"`, + reflect.TypeOf(m.onDuplicate), + ) } - } else if onDuplicateExKeySet.Size() > 0 { - option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range columnNames { - if onDuplicateExKeySet.Contains(v) { - continue - } - option.OnDuplicateMap[v] = v + } + } else if onDuplicateExKeySet.Size() > 0 { + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range columnNames { + if onDuplicateExKeySet.Contains(v) { + continue } + option.OnDuplicateMap[v] = v } } return @@ -407,6 +432,28 @@ func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, er } } +func (m *Model) formatOnConflictKeys(onConflict interface{}) ([]string, error) { + if onConflict == nil { + return nil, nil + } + + reflectInfo := reflection.OriginValueAndKind(onConflict) + switch reflectInfo.OriginKind { + case reflect.String: + return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil + + case reflect.Slice, reflect.Array: + return gconv.Strings(onConflict), nil + + default: + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported onConflict parameter type "%s"`, + reflect.TypeOf(onConflict), + ) + } +} + func (m *Model) getBatch() int { return m.batch }