gf/contrib/drivers/oracle/oracle_do_insert.go

209 lines
5.7 KiB
Go

// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package oracle
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/util/gconv"
)
// DoInsert inserts or updates data for given table.
//
// Dispatch by option.InsertOption:
//   - gdb.InsertOptionSave is delegated to doSave (MERGE based upsert).
//   - gdb.InsertOptionReplace is rejected with gcode.CodeNotSupported.
//   - Any other option performs a batched "INSERT ALL ... SELECT * FROM DUAL".
//
// The column set is taken from the keys of the first row of list, and every
// row is read using those same keys. A statement is executed whenever the
// number of pending rows equals option.BatchCount, and once more for the
// remaining rows when the last row is reached.
//
// The returned result carries the sql.Result of the last executed statement
// and the sum of the affected rows of all executed statements. An empty list
// yields a gcode.CodeInvalidRequest error.
func (d *Driver) DoInsert(
	ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
	switch option.InsertOption {
	case gdb.InsertOptionSave:
		return d.doSave(ctx, link, table, list, option)

	case gdb.InsertOptionReplace:
		return nil, gerror.NewCode(
			gcode.CodeNotSupported,
			`Replace operation is not supported by oracle driver`,
		)
	}

	// The columns are derived from list[0], so an empty list cannot be handled.
	if len(list) == 0 {
		return nil, gerror.NewCode(
			gcode.CodeInvalidRequest, `Insert operation list is empty by oracle driver`,
		)
	}

	// Retrieve the table fields and the matching placeholders from the first row.
	var (
		listLength  = len(list)
		keys        = make([]string, 0, len(list[0]))
		valueHolder = make([]string, 0, len(list[0]))
		params      []interface{}
	)
	for k := range list[0] {
		keys = append(keys, k)
		valueHolder = append(valueHolder, "?")
	}

	var (
		batchResult  = new(gdb.SqlResult)
		charL, charR = d.GetChars()
		// Each key is wrapped as charL+key+charR, so neighbouring keys are
		// separated by: closing char, comma, opening char.
		keyStr         = charL + strings.Join(keys, charR+","+charL) + charR
		valueHolderStr = strings.Join(valueHolder, ",")
		// The INTO clause is identical for every row, only the params differ.
		intoStr = fmt.Sprintf(
			"INTO %s(%s) VALUES(%s)",
			table, keyStr, valueHolderStr,
		)
		intoStrArray = make([]string, 0)
	)

	// Format "INSERT...INTO..." statement.
	for i := 0; i < listLength; i++ {
		for _, k := range keys {
			// gdb.Raw values are converted to plain strings and bound as
			// parameters like any other value.
			if s, ok := list[i][k].(gdb.Raw); ok {
				params = append(params, gconv.String(s))
			} else {
				params = append(params, list[i][k])
			}
		}
		intoStrArray = append(intoStrArray, intoStr)

		// Flush when the batch is full or when the last row has been collected.
		if len(intoStrArray) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) {
			r, err := d.DoExec(ctx, link, fmt.Sprintf(
				"INSERT ALL %s SELECT * FROM DUAL",
				strings.Join(intoStrArray, " "),
			), params...)
			if err != nil {
				return r, err
			}
			n, err := r.RowsAffected()
			if err != nil {
				return r, err
			}
			batchResult.Result = r
			batchResult.Affected += n

			// Reset the buffers for the next batch, keeping their capacity.
			params = params[:0]
			intoStrArray = intoStrArray[:0]
		}
	}
	return batchResult, nil
}
// doSave implements upsert ("save") for Oracle using a MERGE statement.
//
// option.OnConflict must name the conflict columns; they form the ON clause
// of the MERGE and are excluded from the UPDATE SET clause. Columns reported
// by d.Core.IsSoftCreatedFieldName are excluded from the UPDATE SET clause as
// well.
//
// NOTE: only the first row of list (list[0]) is processed; any further rows
// are not read by this function.
//
// It returns gcode.CodeMissingParameter if no conflict column is given and
// gcode.CodeInvalidRequest if list is empty. On success the result carries
// the sql.Result of the MERGE statement and its affected row count.
func (d *Driver) doSave(ctx context.Context,
	link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
	if len(option.OnConflict) == 0 {
		return nil, gerror.NewCode(
			gcode.CodeMissingParameter, `Please specify conflict columns`,
		)
	}

	if len(list) == 0 {
		return nil, gerror.NewCode(
			gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`,
		)
	}

	var (
		one            = list[0]
		oneLen         = len(one)
		charL, charR   = d.GetChars()
		conflictKeys   = option.OnConflict
		conflictKeySet = gset.New(false)

		// queryHolders:  Handle data with Holder that need to be upsert
		// queryValues:   Handle data that need to be upsert
		// insertKeys:    Handle valid keys that need to be inserted
		// insertValues:  Handle values that need to be inserted
		// updateValues:  Handle values that need to be updated
		queryHolders = make([]string, oneLen)
		queryValues  = make([]interface{}, oneLen)
		insertKeys   = make([]string, oneLen)
		insertValues = make([]string, oneLen)
		updateValues []string
	)

	// conflictKeys slice type conv to set type.
	// Keys are stored upper-cased so that the lookup below is case-insensitive.
	for _, conflictKey := range conflictKeys {
		conflictKeySet.Add(gstr.ToUpper(conflictKey))
	}

	index := 0
	for key, value := range one {
		keyWithChar := charL + key + charR
		queryHolders[index] = fmt.Sprintf("? AS %s", keyWithChar)
		queryValues[index] = value
		insertKeys[index] = keyWithChar
		insertValues[index] = fmt.Sprintf("T2.%s", keyWithChar)

		// filter conflict keys in updateValues.
		// And the key is not a soft created field.
		// The key is upper-cased to match how the set was populated above;
		// otherwise a lower-case conflict column would end up in UPDATE SET.
		if !(conflictKeySet.Contains(gstr.ToUpper(key)) || d.Core.IsSoftCreatedFieldName(key)) {
			updateValues = append(
				updateValues,
				fmt.Sprintf(`T1.%s = T2.%s`, keyWithChar, keyWithChar),
			)
		}
		index++
	}

	batchResult := new(gdb.SqlResult)
	sqlStr := parseSqlForUpsert(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys)
	r, err := d.DoExec(ctx, link, sqlStr, queryValues...)
	if err != nil {
		return r, err
	}
	n, err := r.RowsAffected()
	if err != nil {
		return r, err
	}
	batchResult.Result = r
	batchResult.Affected += n
	return batchResult, nil
}
// parseSqlForUpsert builds the Oracle MERGE statement used for upsert:
//
//	MERGE INTO {{table}} T1
//	USING (SELECT {{queryHolders}} FROM DUAL) T2
//	ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...)
//	WHEN NOT MATCHED THEN
//	INSERT({{insertKeys}}) VALUES ({{insertValues}})
//	WHEN MATCHED THEN
//	UPDATE SET {{updateValues}}
//
// The statement is returned on a single line. queryHolders, insertKeys,
// insertValues and updateValues are joined with "," and inserted verbatim;
// each entry of duplicateKey produces one "T1.key = T2.key" condition, and
// the conditions are joined with " AND ".
//
// When updateValues is empty the "WHEN MATCHED THEN UPDATE SET" clause is
// left out entirely, because an UPDATE SET clause without assignments is not
// valid SQL; the statement then only inserts rows that do not exist yet.
func parseSqlForUpsert(table string,
	queryHolders, insertKeys, insertValues, updateValues, duplicateKey []string,
) (sqlStr string) {
	conditions := make([]string, 0, len(duplicateKey))
	for _, key := range duplicateKey {
		conditions = append(conditions, fmt.Sprintf("T1.%s = T2.%s", key, key))
	}

	sqlStr = fmt.Sprintf(
		`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s)`,
		table,
		strings.Join(queryHolders, ","),
		strings.Join(conditions, " AND "),
		strings.Join(insertKeys, ","),
		strings.Join(insertValues, ","),
	)
	if len(updateValues) > 0 {
		sqlStr += " WHEN MATCHED THEN UPDATE SET " + strings.Join(updateValues, ",")
	}
	return sqlStr
}