mirror of
https://gitee.com/johng/gf.git
synced 2024-12-02 04:07:47 +08:00
commit
24ea9f9245
119
.example/database/gdb/driver/driver/driver.go
Normal file
119
.example/database/gdb/driver/driver/driver.go
Normal file
@ -0,0 +1,119 @@
|
||||
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/database/gdb"
|
||||
"github.com/gogf/gf/internal/intlog"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
)
|
||||
|
||||
type MyDriver struct {
|
||||
*gdb.Core
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for mysql.
|
||||
func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
} else {
|
||||
source = fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true&parseTime=true&loc=Local",
|
||||
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset,
|
||||
)
|
||||
}
|
||||
intlog.Printf("Open: %s", source)
|
||||
if db, err := sql.Open("mysql", source); err == nil {
|
||||
return db, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (d *MyDriver) GetChars() (charLeft string, charRight string) {
|
||||
return "`", "`"
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec handles the sql before posts it to database.
|
||||
func (d *MyDriver) HandleSqlBeforeExec(sql string) string {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
func (d *MyDriver) Tables(schema ...string) (tables []string, err error) {
|
||||
var result gdb.Result
|
||||
link, err := d.DB.GetSlave(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, m := range result {
|
||||
for _, v := range m {
|
||||
tables = append(tables, v.String())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// gdb.TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
//
|
||||
// Note that it returns a map containing the field name and its corresponding fields.
|
||||
// As a map is unsorted, the gdb.TableField struct has a "Index" field marks its sequence in the fields.
|
||||
//
|
||||
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
|
||||
func (d *MyDriver) TableFields(table string, schema ...string) (fields map[string]*gdb.TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function gdb.TableFields supports only single table operations")
|
||||
}
|
||||
checkSchema := d.DB.GetSchema()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
checkSchema = schema[0]
|
||||
}
|
||||
v := d.DB.GetCache().GetOrSetFunc(
|
||||
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
|
||||
func() interface{} {
|
||||
var result gdb.Result
|
||||
var link *sql.DB
|
||||
link, err = d.DB.GetSlave(checkSchema)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result, err = d.DB.DoGetAll(
|
||||
link,
|
||||
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
fields = make(map[string]*gdb.TableField)
|
||||
for i, m := range result {
|
||||
fields[m["Field"].String()] = &gdb.TableField{
|
||||
Index: i,
|
||||
Name: m["Field"].String(),
|
||||
Type: m["Type"].String(),
|
||||
Null: m["Null"].Bool(),
|
||||
Key: m["Key"].String(),
|
||||
Default: m["Default"].Val(),
|
||||
Extra: m["Extra"].String(),
|
||||
Comment: m["Comment"].String(),
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}, 0)
|
||||
if err == nil {
|
||||
fields = v.(map[string]*gdb.TableField)
|
||||
}
|
||||
return
|
||||
}
|
1
.example/database/gdb/driver/main.go
Normal file
1
.example/database/gdb/driver/main.go
Normal file
@ -0,0 +1 @@
|
||||
package main
|
@ -11,11 +11,11 @@ func main() {
|
||||
// 开启调试模式,以便于记录所有执行的SQL
|
||||
db.SetDebug(true)
|
||||
|
||||
r, e := db.Table("test").OrderBy("id asc").All()
|
||||
r, e := db.Table("test").Order("id asc").All()
|
||||
if e != nil {
|
||||
panic(e)
|
||||
fmt.Println(e)
|
||||
}
|
||||
if r != nil {
|
||||
fmt.Println(r.ToList())
|
||||
fmt.Println(r.List())
|
||||
}
|
||||
}
|
||||
|
@ -9,5 +9,4 @@ func main() {
|
||||
db.SetDebug(true)
|
||||
|
||||
db.Table("user").Fields("DISTINCT id,nickname").Filter().All()
|
||||
|
||||
}
|
||||
|
@ -1,37 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
)
|
||||
|
||||
func MiddlewareAuth(r *ghttp.Request) {
|
||||
token := r.Get("token")
|
||||
if token == "123456" {
|
||||
r.Response.Writeln("auth")
|
||||
r.Middleware.Next()
|
||||
} else {
|
||||
r.Response.WriteStatus(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func MiddlewareCORS(r *ghttp.Request) {
|
||||
r.Response.Writeln("cors")
|
||||
r.Response.CORSDefault()
|
||||
r.Middleware.Next()
|
||||
}
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.Use(MiddlewareCORS)
|
||||
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
|
||||
group.Middleware(MiddlewareAuth)
|
||||
group.ALL("/user/list", func(r *ghttp.Request) {
|
||||
r.Response.Writeln("list")
|
||||
})
|
||||
data := "@var(.prefix)您收到的验证码为:@var(.code),请在@var(.expire)内完成验证"
|
||||
result, err := gregex.ReplaceStringFuncMatch(`(@var\(\.\w+\))`, data, func(match []string) string {
|
||||
fmt.Println(match)
|
||||
return "#"
|
||||
})
|
||||
s.SetPort(8199)
|
||||
s.Run()
|
||||
fmt.Println(err)
|
||||
fmt.Println(result)
|
||||
}
|
||||
|
@ -4,13 +4,12 @@ import (
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/os/gview"
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := ghttp.GetServer()
|
||||
s.BindHandler("/page/demo", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String())
|
||||
page := r.GetPage(100, 10)
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
<head>
|
||||
|
@ -4,14 +4,13 @@ import (
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/os/gview"
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := ghttp.GetServer()
|
||||
s.BindHandler("/page/ajax", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
|
||||
page.EnableAjax("DoAjax")
|
||||
page := r.GetPage(100, 10)
|
||||
page.AjaxActionName = "DoAjax"
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
<head>
|
||||
@ -29,11 +28,17 @@ func main() {
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div>{{.page}}</div>
|
||||
<div>{{.page1}}</div>
|
||||
<div>{{.page2}}</div>
|
||||
<div>{{.page3}}</div>
|
||||
<div>{{.page4}}</div>
|
||||
</body>
|
||||
</html>
|
||||
`, g.Map{
|
||||
"page": page.GetContent(1),
|
||||
"page1": page.GetContent(1),
|
||||
"page2": page.GetContent(2),
|
||||
"page3": page.GetContent(3),
|
||||
"page4": page.GetContent(4),
|
||||
})
|
||||
r.Response.Write(buffer)
|
||||
})
|
||||
|
@ -23,7 +23,7 @@ func wrapContent(page *gpage.Page) string {
|
||||
func main() {
|
||||
s := ghttp.GetServer()
|
||||
s.BindHandler("/page/custom1/*page", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
|
||||
page := r.GetPage(100, 10)
|
||||
content := wrapContent(page)
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
|
@ -15,7 +15,7 @@ func pageContent(page *gpage.Page) string {
|
||||
page.LastPageTag = "LastPage"
|
||||
pageStr := page.FirstPage()
|
||||
pageStr += page.PrevPage()
|
||||
pageStr += page.PageBar("current-page")
|
||||
pageStr += page.PageBar()
|
||||
pageStr += page.NextPage()
|
||||
pageStr += page.LastPage()
|
||||
return pageStr
|
||||
@ -24,7 +24,7 @@ func pageContent(page *gpage.Page) string {
|
||||
func main() {
|
||||
s := ghttp.GetServer()
|
||||
s.BindHandler("/page/custom2/*page", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
|
||||
page := r.GetPage(100, 10)
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
<head>
|
||||
|
@ -4,13 +4,12 @@ import (
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/os/gview"
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.BindHandler("/page/static/*page", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
|
||||
page := r.GetPage(100, 10)
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
<head>
|
||||
|
@ -4,13 +4,12 @@ import (
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/os/gview"
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.BindHandler("/:obj/*action/{page}.html", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
|
||||
page := r.GetPage(100, 10)
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
<head>
|
||||
|
@ -4,14 +4,13 @@ import (
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/os/gview"
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.BindHandler("/page/template/{page}.html", func(r *ghttp.Request) {
|
||||
page := gpage.New(100, 10, r.Get("page"), r.URL.String())
|
||||
page.SetUrlTemplate("/order/list/{.page}.html")
|
||||
page := r.GetPage(100, 10)
|
||||
page.UrlTemplate = "/order/list/{.page}.html"
|
||||
buffer, _ := gview.ParseContent(`
|
||||
<html>
|
||||
<head>
|
||||
|
@ -15,7 +15,7 @@ We currently accept donation by Alipay/WechatPay, please note your github/gitee
|
||||
|[zhuhuan12](https://gitee.com/zhuhuan12)|gitee|¥50.00 |
|
||||
|[zfan_codes](https://gitee.com/zfan_codes)|gitee|¥10.00 |
|
||||
|[arden](https://github.com/arden)|alipay|¥10.00 |
|
||||
|[macnie](https://www.macnie.com)|wechat|¥100.00 |
|
||||
|[macnie](https://www.macnie.com)|wechat|¥110.00 |
|
||||
|lah|wechat|¥100.00 |
|
||||
|x*z|wechat|¥20.00 |
|
||||
|潘兄|wechat|¥100.00 |
|
||||
|
@ -439,6 +439,9 @@ func (a *SortedArray) SetUnique(unique bool) *SortedArray {
|
||||
func (a *SortedArray) Unique() *SortedArray {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if len(a.array) == 0 {
|
||||
return a
|
||||
}
|
||||
i := 0
|
||||
for {
|
||||
if i == len(a.array)-1 {
|
||||
|
@ -429,6 +429,10 @@ func (a *SortedIntArray) SetUnique(unique bool) *SortedIntArray {
|
||||
// Unique uniques the array, clear repeated items.
|
||||
func (a *SortedIntArray) Unique() *SortedIntArray {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if len(a.array) == 0 {
|
||||
return a
|
||||
}
|
||||
i := 0
|
||||
for {
|
||||
if i == len(a.array)-1 {
|
||||
@ -440,7 +444,6 @@ func (a *SortedIntArray) Unique() *SortedIntArray {
|
||||
i++
|
||||
}
|
||||
}
|
||||
a.mu.Unlock()
|
||||
return a
|
||||
}
|
||||
|
||||
|
@ -414,6 +414,10 @@ func (a *SortedStrArray) SetUnique(unique bool) *SortedStrArray {
|
||||
// Unique uniques the array, clear repeated items.
|
||||
func (a *SortedStrArray) Unique() *SortedStrArray {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if len(a.array) == 0 {
|
||||
return a
|
||||
}
|
||||
i := 0
|
||||
for {
|
||||
if i == len(a.array)-1 {
|
||||
@ -425,7 +429,6 @@ func (a *SortedStrArray) Unique() *SortedStrArray {
|
||||
i++
|
||||
}
|
||||
}
|
||||
a.mu.Unlock()
|
||||
return a
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
func Example_basic() {
|
||||
// 创建普通的数组,默认并发安全(带锁)
|
||||
a := garray.New(true)
|
||||
// 创建普通的数组
|
||||
a := garray.New()
|
||||
|
||||
// 添加数据项
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -4,7 +4,7 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package gqueue provides a dynamic/static concurrent-safe queue.
|
||||
// Package gqueue provides dynamic/static concurrent-safe queue.
|
||||
//
|
||||
// Features:
|
||||
//
|
||||
@ -25,6 +25,7 @@ import (
|
||||
"github.com/gogf/gf/container/gtype"
|
||||
)
|
||||
|
||||
// Queue is a concurrent-safe queue built on doubly linked list and channel.
|
||||
type Queue struct {
|
||||
limit int // Limit for queue size.
|
||||
list *glist.List // Underlying list structure for data maintaining.
|
||||
@ -54,14 +55,14 @@ func New(limit ...int) *Queue {
|
||||
q.list = glist.New(true)
|
||||
q.events = make(chan struct{}, math.MaxInt32)
|
||||
q.C = make(chan interface{}, gDEFAULT_QUEUE_SIZE)
|
||||
go q.startAsyncLoop()
|
||||
go q.asyncLoopFromListToChannel()
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// startAsyncLoop starts an asynchronous goroutine,
|
||||
// asyncLoopFromListToChannel starts an asynchronous goroutine,
|
||||
// which handles the data synchronization from list <q.list> to channel <q.C>.
|
||||
func (q *Queue) startAsyncLoop() {
|
||||
func (q *Queue) asyncLoopFromListToChannel() {
|
||||
defer func() {
|
||||
if q.closed.Val() {
|
||||
_ = recover()
|
||||
|
@ -22,7 +22,7 @@ import (
|
||||
"github.com/gogf/gf/util/grand"
|
||||
)
|
||||
|
||||
// DB is the interface for ORM operations.
|
||||
// DB defines the interfaces for ORM operations.
|
||||
type DB interface {
|
||||
// Open creates a raw connection object for database with given node configuration.
|
||||
// Note that it is not recommended using the this function manually.
|
||||
@ -34,14 +34,14 @@ type DB interface {
|
||||
Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error)
|
||||
|
||||
// Internal APIs for CURD, which can be overwrote for custom CURD implements.
|
||||
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
|
||||
doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error)
|
||||
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
|
||||
doPrepare(link dbLink, query string) (*sql.Stmt, error)
|
||||
doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error)
|
||||
doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error)
|
||||
doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
|
||||
doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error)
|
||||
DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error)
|
||||
DoGetAll(link Link, query string, args ...interface{}) (result Result, err error)
|
||||
DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error)
|
||||
DoPrepare(link Link, query string) (*sql.Stmt, error)
|
||||
DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error)
|
||||
DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error)
|
||||
DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
|
||||
DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error)
|
||||
|
||||
// Query APIs for convenience purpose.
|
||||
GetAll(query string, args ...interface{}) (Result, error)
|
||||
@ -52,11 +52,11 @@ type DB interface {
|
||||
GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error
|
||||
GetScan(objPointer interface{}, query string, args ...interface{}) error
|
||||
|
||||
// Master/Slave support.
|
||||
// Master/Slave specification support.
|
||||
Master() (*sql.DB, error)
|
||||
Slave() (*sql.DB, error)
|
||||
|
||||
// Ping.
|
||||
// Ping-Pong.
|
||||
PingMaster() error
|
||||
PingSlave() error
|
||||
|
||||
@ -75,48 +75,49 @@ type DB interface {
|
||||
Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error)
|
||||
Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error)
|
||||
|
||||
// Create model.
|
||||
// Model creation.
|
||||
From(tables string) *Model
|
||||
Table(tables string) *Model
|
||||
Schema(schema string) *Schema
|
||||
|
||||
// Configuration methods.
|
||||
GetCache() *gcache.Cache
|
||||
SetDebug(debug bool)
|
||||
GetDebug() bool
|
||||
SetSchema(schema string)
|
||||
GetSchema() string
|
||||
GetPrefix() string
|
||||
SetLogger(logger *glog.Logger)
|
||||
GetLogger() *glog.Logger
|
||||
SetMaxIdleConnCount(n int)
|
||||
SetMaxOpenConnCount(n int)
|
||||
SetMaxConnLifetime(d time.Duration)
|
||||
|
||||
// Utility methods.
|
||||
GetChars() (charLeft string, charRight string)
|
||||
GetMaster(schema ...string) (*sql.DB, error)
|
||||
GetSlave(schema ...string) (*sql.DB, error)
|
||||
QuoteWord(s string) string
|
||||
QuoteString(s string) string
|
||||
QuotePrefixTableName(table string) string
|
||||
Tables(schema ...string) (tables []string, err error)
|
||||
TableFields(table string, schema ...string) (map[string]*TableField, error)
|
||||
|
||||
// HandleSqlBeforeCommit is a hook function, which deals with the sql string before
|
||||
// it's committed to underlying driver. The parameter <link> specifies the current
|
||||
// database connection operation object. You can modify the sql string <query> and its
|
||||
// arguments <args> as you wish before they're committed to driver.
|
||||
HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{})
|
||||
|
||||
// Internal methods.
|
||||
getCache() *gcache.Cache
|
||||
getChars() (charLeft string, charRight string)
|
||||
getDebug() bool
|
||||
getPrefix() string
|
||||
getMaster(schema ...string) (*sql.DB, error)
|
||||
getSlave(schema ...string) (*sql.DB, error)
|
||||
quoteWord(s string) string
|
||||
quoteString(s string) string
|
||||
handleTableName(table string) string
|
||||
filterFields(schema, table string, data map[string]interface{}) map[string]interface{}
|
||||
convertValue(fieldValue []byte, fieldType string) interface{}
|
||||
rowsToResult(rows *sql.Rows) (Result, error)
|
||||
handleSqlBeforeExec(sql string) string
|
||||
}
|
||||
|
||||
// dbLink is a common database function wrapper interface for internal usage.
|
||||
type dbLink interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// dbBase is the base struct for database management.
|
||||
type dbBase struct {
|
||||
db DB // DB interface object.
|
||||
// Core is the base struct for database management.
|
||||
type Core struct {
|
||||
DB DB // DB interface object.
|
||||
group string // Configuration group name.
|
||||
debug *gtype.Bool // Enable debug mode for the database.
|
||||
cache *gcache.Cache // Cache manager.
|
||||
@ -128,6 +129,12 @@ type dbBase struct {
|
||||
maxConnLifetime time.Duration // Max TTL for a connection.
|
||||
}
|
||||
|
||||
// Driver is the interface for integrating sql drivers into package gdb.
|
||||
type Driver interface {
|
||||
// New creates and returns a database object for specified database server.
|
||||
New(core *Core, node *ConfigNode) (DB, error)
|
||||
}
|
||||
|
||||
// Sql is the sql recording struct.
|
||||
type Sql struct {
|
||||
Sql string // SQL string(may contain reserved char '?').
|
||||
@ -150,6 +157,13 @@ type TableField struct {
|
||||
Comment string // Comment.
|
||||
}
|
||||
|
||||
// Link is a common database function wrapper interface.
|
||||
type Link interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// Value is the field value type.
|
||||
type Value = *gvar.Var
|
||||
|
||||
@ -176,10 +190,24 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// Instance map.
|
||||
// instances is the management map for instances.
|
||||
instances = gmap.NewStrAnyMap(true)
|
||||
// driverMap manages all custom registered driver.
|
||||
driverMap = map[string]Driver{
|
||||
"mysql": &DriverMysql{},
|
||||
"mssql": &DriverMssql{},
|
||||
"pgsql": &DriverPgsql{},
|
||||
"oracle": &DriverOracle{},
|
||||
"sqlite": &DriverSqlite{},
|
||||
}
|
||||
)
|
||||
|
||||
// Register registers custom database driver to gdb.
|
||||
func Register(name string, driver Driver) error {
|
||||
driverMap[name] = driver
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates and returns an ORM object with global configurations.
|
||||
// The parameter <name> specifies the configuration group name,
|
||||
// which is DEFAULT_GROUP_NAME in default.
|
||||
@ -196,31 +224,24 @@ func New(name ...string) (db DB, err error) {
|
||||
}
|
||||
if _, ok := configs.config[group]; ok {
|
||||
if node, err := getConfigNodeByGroup(group, true); err == nil {
|
||||
base := &dbBase{
|
||||
group: group,
|
||||
debug: gtype.NewBool(),
|
||||
cache: gcache.New(),
|
||||
schema: gtype.NewString(),
|
||||
logger: glog.New(),
|
||||
prefix: node.Prefix,
|
||||
// Default max connection life time if user does not configure.
|
||||
maxConnLifetime: gDEFAULT_CONN_MAX_LIFE_TIME,
|
||||
c := &Core{
|
||||
group: group,
|
||||
debug: gtype.NewBool(),
|
||||
cache: gcache.New(),
|
||||
schema: gtype.NewString(),
|
||||
logger: glog.New(),
|
||||
prefix: node.Prefix,
|
||||
maxConnLifetime: gDEFAULT_CONN_MAX_LIFE_TIME, // Default max connection life time if user does not configure.
|
||||
}
|
||||
switch node.Type {
|
||||
case "mysql":
|
||||
base.db = &dbMysql{dbBase: base}
|
||||
case "pgsql":
|
||||
base.db = &dbPgsql{dbBase: base}
|
||||
case "mssql":
|
||||
base.db = &dbMssql{dbBase: base}
|
||||
case "sqlite":
|
||||
base.db = &dbSqlite{dbBase: base}
|
||||
case "oracle":
|
||||
base.db = &dbOracle{dbBase: base}
|
||||
default:
|
||||
if v, ok := driverMap[node.Type]; ok {
|
||||
c.DB, err = v.New(c, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.DB, nil
|
||||
} else {
|
||||
return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type))
|
||||
}
|
||||
return base.db, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
@ -321,9 +342,9 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
|
||||
// getSqlDb retrieves and returns a underlying database connection object.
|
||||
// The parameter <master> specifies whether retrieves master node connection if
|
||||
// master-slave nodes are configured.
|
||||
func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
|
||||
func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
|
||||
// Load balance.
|
||||
node, err := getConfigNodeByGroup(bs.group, master)
|
||||
node, err := getConfigNodeByGroup(c.group, master)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -332,7 +353,7 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er
|
||||
node.Charset = "utf8"
|
||||
}
|
||||
// Changes the schema.
|
||||
nodeSchema := bs.schema.Val()
|
||||
nodeSchema := c.schema.Val()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
nodeSchema = schema[0]
|
||||
}
|
||||
@ -343,25 +364,25 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er
|
||||
node = &n
|
||||
}
|
||||
// Cache the underlying connection object by node.
|
||||
v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
sqlDb, err = bs.db.Open(node)
|
||||
v := c.cache.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
sqlDb, err = c.DB.Open(node)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if bs.maxIdleConnCount > 0 {
|
||||
sqlDb.SetMaxIdleConns(bs.maxIdleConnCount)
|
||||
if c.maxIdleConnCount > 0 {
|
||||
sqlDb.SetMaxIdleConns(c.maxIdleConnCount)
|
||||
} else if node.MaxIdleConnCount > 0 {
|
||||
sqlDb.SetMaxIdleConns(node.MaxIdleConnCount)
|
||||
}
|
||||
|
||||
if bs.maxOpenConnCount > 0 {
|
||||
sqlDb.SetMaxOpenConns(bs.maxOpenConnCount)
|
||||
if c.maxOpenConnCount > 0 {
|
||||
sqlDb.SetMaxOpenConns(c.maxOpenConnCount)
|
||||
} else if node.MaxOpenConnCount > 0 {
|
||||
sqlDb.SetMaxOpenConns(node.MaxOpenConnCount)
|
||||
}
|
||||
|
||||
if bs.maxConnLifetime > 0 {
|
||||
sqlDb.SetConnMaxLifetime(bs.maxConnLifetime * time.Second)
|
||||
if c.maxConnLifetime > 0 {
|
||||
sqlDb.SetConnMaxLifetime(c.maxConnLifetime * time.Second)
|
||||
} else if node.MaxConnLifetime > 0 {
|
||||
sqlDb.SetConnMaxLifetime(node.MaxConnLifetime * time.Second)
|
||||
}
|
||||
@ -371,40 +392,7 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er
|
||||
sqlDb = v.(*sql.DB)
|
||||
}
|
||||
if node.Debug {
|
||||
bs.db.SetDebug(node.Debug)
|
||||
c.DB.SetDebug(node.Debug)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetSchema changes the schema for this database connection object.
|
||||
// Importantly note that when schema configuration changed for the database,
|
||||
// it affects all operations on the database object in the future.
|
||||
func (bs *dbBase) SetSchema(schema string) {
|
||||
bs.schema.Set(schema)
|
||||
}
|
||||
|
||||
// Master creates and returns a connection from master node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (bs *dbBase) Master() (*sql.DB, error) {
|
||||
return bs.getSqlDb(true, bs.schema.Val())
|
||||
}
|
||||
|
||||
// Slave creates and returns a connection from slave node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (bs *dbBase) Slave() (*sql.DB, error) {
|
||||
return bs.getSqlDb(false, bs.schema.Val())
|
||||
}
|
||||
|
||||
// getMaster acts like function Master but with additional <schema> parameter specifying
|
||||
// the schema for the connection. It is defined for internal usage.
|
||||
// Also see Master.
|
||||
func (bs *dbBase) getMaster(schema ...string) (*sql.DB, error) {
|
||||
return bs.getSqlDb(true, schema...)
|
||||
}
|
||||
|
||||
// getSlave acts like function Slave but with additional <schema> parameter specifying
|
||||
// the schema for the connection. It is defined for internal usage.
|
||||
// Also see Slave.
|
||||
func (bs *dbBase) getSlave(schema ...string) (*sql.DB, error) {
|
||||
return bs.getSqlDb(false, schema...)
|
||||
}
|
||||
|
@ -16,7 +16,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/container/gvar"
|
||||
"github.com/gogf/gf/os/gcache"
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
"github.com/gogf/gf/util/gconv"
|
||||
@ -32,22 +31,34 @@ var (
|
||||
lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`)
|
||||
)
|
||||
|
||||
// Master creates and returns a connection from master node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (c *Core) Master() (*sql.DB, error) {
|
||||
return c.getSqlDb(true, c.schema.Val())
|
||||
}
|
||||
|
||||
// Slave creates and returns a connection from slave node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (c *Core) Slave() (*sql.DB, error) {
|
||||
return c.getSqlDb(false, c.schema.Val())
|
||||
}
|
||||
|
||||
// Query commits one query SQL to underlying driver and returns the execution result.
|
||||
// It is most commonly used for data querying.
|
||||
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
link, err := bs.db.Slave()
|
||||
func (c *Core) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
link, err := c.DB.Slave()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bs.db.doQuery(link, query, args...)
|
||||
return c.DB.DoQuery(link, query, args...)
|
||||
}
|
||||
|
||||
// doQuery commits the query string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
func (c *Core) DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
query, args = formatQuery(query, args)
|
||||
query = bs.db.handleSqlBeforeExec(query)
|
||||
if bs.db.getDebug() {
|
||||
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
|
||||
if c.DB.GetDebug() {
|
||||
mTime1 := gtime.TimestampMilli()
|
||||
rows, err = link.Query(query, args...)
|
||||
mTime2 := gtime.TimestampMilli()
|
||||
@ -59,7 +70,7 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
}
|
||||
bs.printSql(s)
|
||||
c.writeSqlToLogger(s)
|
||||
} else {
|
||||
rows, err = link.Query(query, args...)
|
||||
}
|
||||
@ -73,20 +84,20 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows
|
||||
|
||||
// Exec commits one query SQL to underlying driver and returns the execution result.
|
||||
// It is most commonly used for data inserting and updating.
|
||||
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := bs.db.Master()
|
||||
func (c *Core) Exec(query string, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := c.DB.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bs.db.doExec(link, query, args...)
|
||||
return c.DB.DoExec(link, query, args...)
|
||||
}
|
||||
|
||||
// doExec commits the query string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
func (c *Core) DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
query, args = formatQuery(query, args)
|
||||
query = bs.db.handleSqlBeforeExec(query)
|
||||
if bs.db.getDebug() {
|
||||
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
|
||||
if c.DB.GetDebug() {
|
||||
mTime1 := gtime.TimestampMilli()
|
||||
result, err = link.Exec(query, args...)
|
||||
mTime2 := gtime.TimestampMilli()
|
||||
@ -98,7 +109,7 @@ func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
}
|
||||
bs.printSql(s)
|
||||
c.writeSqlToLogger(s)
|
||||
} else {
|
||||
result, err = link.Exec(query, args...)
|
||||
}
|
||||
@ -113,50 +124,50 @@ func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result
|
||||
//
|
||||
// The parameter <execOnMaster> specifies whether executing the sql on master node,
|
||||
// or else it executes the sql on slave node if master-slave configured.
|
||||
func (bs *dbBase) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
|
||||
func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
|
||||
err := (error)(nil)
|
||||
link := (dbLink)(nil)
|
||||
link := (Link)(nil)
|
||||
if len(execOnMaster) > 0 && execOnMaster[0] {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
if link, err = c.DB.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if link, err = bs.db.Slave(); err != nil {
|
||||
if link, err = c.DB.Slave(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doPrepare(link, query)
|
||||
return c.DB.DoPrepare(link, query)
|
||||
}
|
||||
|
||||
// doPrepare calls prepare function on given link object and returns the statement object.
|
||||
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
|
||||
func (c *Core) DoPrepare(link Link, query string) (*sql.Stmt, error) {
|
||||
return link.Prepare(query)
|
||||
}
|
||||
|
||||
// GetAll queries and returns data records from database.
|
||||
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
return bs.db.doGetAll(nil, query, args...)
|
||||
func (c *Core) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
return c.DB.DoGetAll(nil, query, args...)
|
||||
}
|
||||
|
||||
// doGetAll queries and returns data records from database.
|
||||
func (bs *dbBase) doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) {
|
||||
func (c *Core) DoGetAll(link Link, query string, args ...interface{}) (result Result, err error) {
|
||||
if link == nil {
|
||||
link, err = bs.db.Slave()
|
||||
link, err = c.DB.Slave()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := bs.doQuery(link, query, args...)
|
||||
rows, err := c.DB.DoQuery(link, query, args...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return bs.db.rowsToResult(rows)
|
||||
return c.DB.rowsToResult(rows)
|
||||
}
|
||||
|
||||
// GetOne queries and returns one record from database.
|
||||
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := bs.GetAll(query, args...)
|
||||
func (c *Core) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := c.DB.GetAll(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -168,8 +179,8 @@ func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
|
||||
// GetStruct queries one record from database and converts it to given struct.
|
||||
// The parameter <pointer> should be a pointer to struct.
|
||||
func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface{}) error {
|
||||
one, err := bs.GetOne(query, args...)
|
||||
func (c *Core) GetStruct(pointer interface{}, query string, args ...interface{}) error {
|
||||
one, err := c.DB.GetOne(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -181,8 +192,8 @@ func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface
|
||||
|
||||
// GetStructs queries records from database and converts them to given struct.
|
||||
// The parameter <pointer> should be type of struct slice: []struct/[]*struct.
|
||||
func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interface{}) error {
|
||||
all, err := bs.GetAll(query, args...)
|
||||
func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}) error {
|
||||
all, err := c.DB.GetAll(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -198,7 +209,7 @@ func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interfac
|
||||
// If parameter <pointer> is type of struct pointer, it calls GetStruct internally for
|
||||
// the conversion. If parameter <pointer> is type of slice, it calls GetStructs internally
|
||||
// for conversion.
|
||||
func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}) error {
|
||||
func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) error {
|
||||
t := reflect.TypeOf(pointer)
|
||||
k := t.Kind()
|
||||
if k != reflect.Ptr {
|
||||
@ -207,9 +218,9 @@ func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}
|
||||
k = t.Elem().Kind()
|
||||
switch k {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return bs.db.GetStructs(pointer, query, args...)
|
||||
return c.DB.GetStructs(pointer, query, args...)
|
||||
case reflect.Struct:
|
||||
return bs.db.GetStruct(pointer, query, args...)
|
||||
return c.DB.GetStruct(pointer, query, args...)
|
||||
}
|
||||
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
|
||||
}
|
||||
@ -217,8 +228,8 @@ func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}
|
||||
// GetValue queries and returns the field value from database.
|
||||
// The sql should queries only one field from database, or else it returns only one
|
||||
// field of the result.
|
||||
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := bs.GetOne(query, args...)
|
||||
func (c *Core) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := c.DB.GetOne(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -229,13 +240,13 @@ func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
}
|
||||
|
||||
// GetCount queries and returns the count from database.
|
||||
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
|
||||
func (c *Core) GetCount(query string, args ...interface{}) (int, error) {
|
||||
// If the query fields do not contains function "COUNT",
|
||||
// it replaces the query string and adds the "COUNT" function to the fields.
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
}
|
||||
value, err := bs.GetValue(query, args...)
|
||||
value, err := c.DB.GetValue(query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -243,8 +254,8 @@ func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
|
||||
}
|
||||
|
||||
// PingMaster pings the master node to check authentication or keeps the connection alive.
|
||||
func (bs *dbBase) PingMaster() error {
|
||||
if master, err := bs.db.Master(); err != nil {
|
||||
func (c *Core) PingMaster() error {
|
||||
if master, err := c.DB.Master(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return master.Ping()
|
||||
@ -252,8 +263,8 @@ func (bs *dbBase) PingMaster() error {
|
||||
}
|
||||
|
||||
// PingSlave pings the slave node to check authentication or keeps the connection alive.
|
||||
func (bs *dbBase) PingSlave() error {
|
||||
if slave, err := bs.db.Slave(); err != nil {
|
||||
func (c *Core) PingSlave() error {
|
||||
if slave, err := c.DB.Slave(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return slave.Ping()
|
||||
@ -264,13 +275,13 @@ func (bs *dbBase) PingSlave() error {
|
||||
// You should call Commit or Rollback functions of the transaction object
|
||||
// if you no longer use the transaction. Commit or Rollback functions will also
|
||||
// close the transaction automatically.
|
||||
func (bs *dbBase) Begin() (*TX, error) {
|
||||
if master, err := bs.db.Master(); err != nil {
|
||||
func (c *Core) Begin() (*TX, error) {
|
||||
if master, err := c.DB.Master(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
if tx, err := master.Begin(); err == nil {
|
||||
return &TX{
|
||||
db: bs.db,
|
||||
db: c.DB,
|
||||
tx: tx,
|
||||
master: master,
|
||||
}, nil
|
||||
@ -289,8 +300,8 @@ func (bs *dbBase) Begin() (*TX, error) {
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter <batch> specifies the batch operation count when given data is slice.
|
||||
func (bs *dbBase) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_DEFAULT, batch...)
|
||||
func (c *Core) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_DEFAULT, batch...)
|
||||
}
|
||||
|
||||
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
|
||||
@ -302,8 +313,8 @@ func (bs *dbBase) Insert(table string, data interface{}, batch ...int) (sql.Resu
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter <batch> specifies the batch operation count when given data is slice.
|
||||
func (bs *dbBase) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_IGNORE, batch...)
|
||||
func (c *Core) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_IGNORE, batch...)
|
||||
}
|
||||
|
||||
// Replace does "REPLACE INTO ..." statement for the table.
|
||||
@ -318,8 +329,8 @@ func (bs *dbBase) InsertIgnore(table string, data interface{}, batch ...int) (sq
|
||||
// The parameter <data> can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// If given data is type of slice, it then does batch replacing, and the optional parameter
|
||||
// <batch> specifies the batch operation count.
|
||||
func (bs *dbBase) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_REPLACE, batch...)
|
||||
func (c *Core) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_REPLACE, batch...)
|
||||
}
|
||||
|
||||
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
|
||||
@ -333,54 +344,60 @@ func (bs *dbBase) Replace(table string, data interface{}, batch ...int) (sql.Res
|
||||
//
|
||||
// If given data is type of slice, it then does batch saving, and the optional parameter
|
||||
// <batch> specifies the batch operation count.
|
||||
func (bs *dbBase) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_SAVE, batch...)
|
||||
func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_SAVE, batch...)
|
||||
}
|
||||
|
||||
// doInsert inserts or updates data for given table.
|
||||
//
|
||||
// The parameter <data> can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
|
||||
// Eg:
|
||||
// Data(g.Map{"uid": 10000, "name":"john"})
|
||||
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
|
||||
//
|
||||
// The parameter <option> values are as follows:
|
||||
// 0: insert: just insert, if there's unique/primary key in the data, it returns error;
|
||||
// 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
|
||||
// 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one;
|
||||
// 3: ignore: if there's unique/primary key in the data, it ignores the inserting;
|
||||
func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
func (c *Core) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
var fields []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
var dataMap Map
|
||||
table = bs.db.handleTableName(table)
|
||||
rv := reflect.ValueOf(data)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
table = c.DB.QuotePrefixTableName(table)
|
||||
reflectValue := reflect.ValueOf(data)
|
||||
reflectKind := reflectValue.Kind()
|
||||
if reflectKind == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
reflectKind = reflectValue.Kind()
|
||||
}
|
||||
switch kind {
|
||||
switch reflectKind {
|
||||
case reflect.Slice, reflect.Array:
|
||||
return bs.db.doBatchInsert(link, table, data, option, batch...)
|
||||
return c.DB.DoBatchInsert(link, table, data, option, batch...)
|
||||
case reflect.Map, reflect.Struct:
|
||||
dataMap = varToMapDeep(data)
|
||||
dataMap = DataToMapDeep(data)
|
||||
default:
|
||||
return result, errors.New(fmt.Sprint("unsupported data type:", kind))
|
||||
return result, errors.New(fmt.Sprint("unsupported data type:", reflectKind))
|
||||
}
|
||||
if len(dataMap) == 0 {
|
||||
return nil, errors.New("data cannot be empty")
|
||||
}
|
||||
charL, charR := bs.db.getChars()
|
||||
charL, charR := c.DB.GetChars()
|
||||
for k, v := range dataMap {
|
||||
fields = append(fields, charL+k+charR)
|
||||
values = append(values, "?")
|
||||
params = append(params, v)
|
||||
}
|
||||
operation := getInsertOperationByOption(option)
|
||||
operation := GetInsertOperationByOption(option)
|
||||
updateStr := ""
|
||||
if option == gINSERT_OPTION_SAVE {
|
||||
for k, _ := range dataMap {
|
||||
if len(updateStr) > 0 {
|
||||
updateStr += ","
|
||||
}
|
||||
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
updateStr += fmt.Sprintf(
|
||||
"%s%s%s=VALUES(%s%s%s)",
|
||||
charL, k, charR,
|
||||
charL, k, charR,
|
||||
)
|
||||
@ -388,45 +405,50 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
|
||||
updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", updateStr)
|
||||
}
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
if link, err = c.DB.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","), updateStr),
|
||||
params...)
|
||||
return c.DB.DoExec(
|
||||
link,
|
||||
fmt.Sprintf(
|
||||
"%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","), updateStr,
|
||||
),
|
||||
params...,
|
||||
)
|
||||
}
|
||||
|
||||
// BatchInsert batch inserts data.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (bs *dbBase) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_DEFAULT, batch...)
|
||||
func (c *Core) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_DEFAULT, batch...)
|
||||
}
|
||||
|
||||
// BatchInsert batch inserts data with ignore option.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (bs *dbBase) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_IGNORE, batch...)
|
||||
func (c *Core) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_IGNORE, batch...)
|
||||
}
|
||||
|
||||
// BatchReplace batch replaces data.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (bs *dbBase) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_REPLACE, batch...)
|
||||
func (c *Core) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_REPLACE, batch...)
|
||||
}
|
||||
|
||||
// BatchSave batch replaces data.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (bs *dbBase) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...)
|
||||
func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...)
|
||||
}
|
||||
|
||||
// doBatchInsert batch inserts/replaces/saves data.
|
||||
func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
var keys, values []string
|
||||
var params []interface{}
|
||||
table = bs.db.handleTableName(table)
|
||||
table = c.DB.QuotePrefixTableName(table)
|
||||
listMap := (List)(nil)
|
||||
switch v := list.(type) {
|
||||
case Result:
|
||||
@ -449,10 +471,10 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
case reflect.Slice, reflect.Array:
|
||||
listMap = make(List, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
listMap[i] = varToMapDeep(rv.Index(i).Interface())
|
||||
listMap[i] = DataToMapDeep(rv.Index(i).Interface())
|
||||
}
|
||||
case reflect.Map, reflect.Struct:
|
||||
listMap = List{varToMapDeep(list)}
|
||||
listMap = List{DataToMapDeep(list)}
|
||||
default:
|
||||
return result, errors.New(fmt.Sprint("unsupported list type:", kind))
|
||||
}
|
||||
@ -461,7 +483,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
return result, errors.New("data list cannot be empty")
|
||||
}
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
if link, err = c.DB.Master(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -471,20 +493,21 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
keys = append(keys, k)
|
||||
holders = append(holders, "?")
|
||||
}
|
||||
// Prepare the result pointer.
|
||||
// Prepare the batch result pointer.
|
||||
batchResult := new(batchSqlResult)
|
||||
charL, charR := bs.db.getChars()
|
||||
charL, charR := c.DB.GetChars()
|
||||
keysStr := charL + strings.Join(keys, charR+","+charL) + charR
|
||||
valueHolderStr := "(" + strings.Join(holders, ",") + ")"
|
||||
|
||||
operation := getInsertOperationByOption(option)
|
||||
operation := GetInsertOperationByOption(option)
|
||||
updateStr := ""
|
||||
if option == gINSERT_OPTION_SAVE {
|
||||
for _, k := range keys {
|
||||
if len(updateStr) > 0 {
|
||||
updateStr += ","
|
||||
}
|
||||
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
updateStr += fmt.Sprintf(
|
||||
"%s%s%s=VALUES(%s%s%s)",
|
||||
charL, k, charR,
|
||||
charL, k, charR,
|
||||
)
|
||||
@ -504,7 +527,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
}
|
||||
values = append(values, valueHolderStr)
|
||||
if len(values) == batchNum || (i == listMapLen-1 && len(values) > 0) {
|
||||
r, err := bs.db.doExec(
|
||||
r, err := c.DB.DoExec(
|
||||
link,
|
||||
fmt.Sprintf(
|
||||
"%s INTO %s(%s) VALUES%s %s",
|
||||
@ -546,18 +569,18 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
// "status IN (?)", g.Slice{1,2,3}
|
||||
// "age IN(?,?)", 18, 50
|
||||
// User{ Id : 1, UserName : "john"}
|
||||
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
newWhere, newArgs := formatWhere(bs.db, condition, args, false)
|
||||
func (c *Core) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
newWhere, newArgs := formatWhere(c.DB, condition, args, false)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return bs.db.doUpdate(nil, table, data, newWhere, newArgs...)
|
||||
return c.DB.DoUpdate(nil, table, data, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
// doUpdate does "UPDATE ... " statement for the table.
|
||||
// Also see Update.
|
||||
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
|
||||
table = bs.db.handleTableName(table)
|
||||
func (c *Core) DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
|
||||
table = c.DB.QuotePrefixTableName(table)
|
||||
updates := ""
|
||||
rv := reflect.ValueOf(data)
|
||||
kind := rv.Kind()
|
||||
@ -569,8 +592,8 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
switch kind {
|
||||
case reflect.Map, reflect.Struct:
|
||||
var fields []string
|
||||
for k, v := range varToMapDeep(data) {
|
||||
fields = append(fields, bs.db.quoteWord(k)+"=?")
|
||||
for k, v := range DataToMapDeep(data) {
|
||||
fields = append(fields, c.DB.QuoteWord(k)+"=?")
|
||||
params = append(params, v)
|
||||
}
|
||||
updates = strings.Join(fields, ",")
|
||||
@ -585,11 +608,15 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
}
|
||||
// If no link passed, it then uses the master link.
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
if link, err = c.DB.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...)
|
||||
return c.DB.DoExec(
|
||||
link,
|
||||
fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition),
|
||||
args...,
|
||||
)
|
||||
}
|
||||
|
||||
// Delete does "DELETE FROM ... " statement for the table.
|
||||
@ -603,38 +630,28 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
// "status IN (?)", g.Slice{1,2,3}
|
||||
// "age IN(?,?)", 18, 50
|
||||
// User{ Id : 1, UserName : "john"}
|
||||
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
newWhere, newArgs := formatWhere(bs.db, condition, args, false)
|
||||
func (c *Core) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
newWhere, newArgs := formatWhere(c.DB, condition, args, false)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return bs.db.doDelete(nil, table, newWhere, newArgs...)
|
||||
return c.DB.DoDelete(nil, table, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
// doDelete does "DELETE FROM ... " statement for the table.
|
||||
// Also see Delete.
|
||||
func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) {
|
||||
func (c *Core) DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) {
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
if link, err = c.DB.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
table = bs.db.handleTableName(table)
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
|
||||
}
|
||||
|
||||
// getCache returns the internal cache object.
|
||||
func (bs *dbBase) getCache() *gcache.Cache {
|
||||
return bs.cache
|
||||
}
|
||||
|
||||
// getPrefix returns the table prefix string configured.
|
||||
func (bs *dbBase) getPrefix() string {
|
||||
return bs.prefix
|
||||
table = c.DB.QuotePrefixTableName(table)
|
||||
return c.DB.DoExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
|
||||
}
|
||||
|
||||
// rowsToResult converts underlying data record type sql.Rows to Result type.
|
||||
func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
func (c *Core) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
if !rows.Next() {
|
||||
return nil, nil
|
||||
}
|
||||
@ -671,7 +688,7 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
// it should do a copy of it.
|
||||
v := make([]byte, len(value))
|
||||
copy(v, value)
|
||||
row[columnNames[i]] = gvar.New(bs.db.convertValue(v, columnTypes[i]))
|
||||
row[columnNames[i]] = gvar.New(c.DB.convertValue(v, columnTypes[i]))
|
||||
}
|
||||
}
|
||||
records = append(records, row)
|
||||
@ -682,39 +699,23 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// handleTableName adds prefix string and quote chars for the table. It handles table string like:
|
||||
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut".
|
||||
// MarshalJSON implements the interface MarshalJSON for json.Marshal.
|
||||
// It just returns the pointer address.
|
||||
//
|
||||
// Note that, this will automatically checks the table prefix whether already added, if true it does
|
||||
// nothing to the table name, or else adds the prefix to the table name.
|
||||
func (bs *dbBase) handleTableName(table string) string {
|
||||
charLeft, charRight := bs.db.getChars()
|
||||
prefix := bs.db.getPrefix()
|
||||
return doHandleTableName(table, prefix, charLeft, charRight)
|
||||
// Note that this interface implements mainly for workaround for a json infinite loop bug
|
||||
// of Golang version < v1.14.
|
||||
func (c *Core) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf(`%+v`, c)), nil
|
||||
}
|
||||
|
||||
// quoteWord checks given string <s> a word, if true quotes it with security chars of the database
|
||||
// and returns the quoted string; or else return <s> without any change.
|
||||
func (bs *dbBase) quoteWord(s string) string {
|
||||
charLeft, charRight := bs.db.getChars()
|
||||
return doQuoteWord(s, charLeft, charRight)
|
||||
}
|
||||
|
||||
// quoteString quotes string with quote chars. Strings like:
|
||||
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
|
||||
func (bs *dbBase) quoteString(s string) string {
|
||||
charLeft, charRight := bs.db.getChars()
|
||||
return doQuoteString(s, charLeft, charRight)
|
||||
}
|
||||
|
||||
// printSql outputs the sql object to logger.
|
||||
// writeSqlToLogger outputs the sql object to logger.
|
||||
// It is enabled when configuration "debug" is true.
|
||||
func (bs *dbBase) printSql(v *Sql) {
|
||||
func (c *Core) writeSqlToLogger(v *Sql) {
|
||||
s := fmt.Sprintf("[%d ms] %s", v.End-v.Start, v.Format)
|
||||
if v.Error != nil {
|
||||
s += "\nError: " + v.Error.Error()
|
||||
bs.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s)
|
||||
c.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s)
|
||||
} else {
|
||||
bs.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s)
|
||||
c.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s)
|
||||
}
|
||||
}
|
@ -8,6 +8,7 @@ package gdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gogf/gf/os/gcache"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -114,29 +115,29 @@ func GetDefaultGroup() string {
|
||||
}
|
||||
|
||||
// SetLogger sets the logger for orm.
|
||||
func (bs *dbBase) SetLogger(logger *glog.Logger) {
|
||||
bs.logger = logger
|
||||
func (c *Core) SetLogger(logger *glog.Logger) {
|
||||
c.logger = logger
|
||||
}
|
||||
|
||||
// GetLogger returns the logger of the orm.
|
||||
func (bs *dbBase) GetLogger() *glog.Logger {
|
||||
return bs.logger
|
||||
func (c *Core) GetLogger() *glog.Logger {
|
||||
return c.logger
|
||||
}
|
||||
|
||||
// SetMaxIdleConnCount sets the max idle connection count for underlying connection pool.
|
||||
func (bs *dbBase) SetMaxIdleConnCount(n int) {
|
||||
bs.maxIdleConnCount = n
|
||||
func (c *Core) SetMaxIdleConnCount(n int) {
|
||||
c.maxIdleConnCount = n
|
||||
}
|
||||
|
||||
// SetMaxOpenConnCount sets the max open connection count for underlying connection pool.
|
||||
func (bs *dbBase) SetMaxOpenConnCount(n int) {
|
||||
bs.maxOpenConnCount = n
|
||||
func (c *Core) SetMaxOpenConnCount(n int) {
|
||||
c.maxOpenConnCount = n
|
||||
}
|
||||
|
||||
// SetMaxConnLifetime sets the connection TTL for underlying connection pool.
|
||||
// If parameter <d> <= 0, it means the connection never expires.
|
||||
func (bs *dbBase) SetMaxConnLifetime(d time.Duration) {
|
||||
bs.maxConnLifetime = d
|
||||
func (c *Core) SetMaxConnLifetime(d time.Duration) {
|
||||
c.maxConnLifetime = d
|
||||
}
|
||||
|
||||
// String returns the node as string.
|
||||
@ -155,14 +156,36 @@ func (node *ConfigNode) String() string {
|
||||
}
|
||||
|
||||
// SetDebug enables/disables the debug mode.
|
||||
func (bs *dbBase) SetDebug(debug bool) {
|
||||
if bs.debug.Val() == debug {
|
||||
func (c *Core) SetDebug(debug bool) {
|
||||
if c.debug.Val() == debug {
|
||||
return
|
||||
}
|
||||
bs.debug.Set(debug)
|
||||
c.debug.Set(debug)
|
||||
}
|
||||
|
||||
// getDebug returns the debug value.
|
||||
func (bs *dbBase) getDebug() bool {
|
||||
return bs.debug.Val()
|
||||
// GetDebug returns the debug value.
|
||||
func (c *Core) GetDebug() bool {
|
||||
return c.debug.Val()
|
||||
}
|
||||
|
||||
// GetCache returns the internal cache object.
|
||||
func (c *Core) GetCache() *gcache.Cache {
|
||||
return c.cache
|
||||
}
|
||||
|
||||
// GetPrefix returns the table prefix string configured.
|
||||
func (c *Core) GetPrefix() string {
|
||||
return c.prefix
|
||||
}
|
||||
|
||||
// SetSchema changes the schema for this database connection object.
|
||||
// Importantly note that when schema configuration changed for the database,
|
||||
// it affects all operations on the database object in the future.
|
||||
func (c *Core) SetSchema(schema string) {
|
||||
c.schema.Set(schema)
|
||||
}
|
||||
|
||||
// GetSchema returns the schema configured.
|
||||
func (c *Core) GetSchema() string {
|
||||
return c.schema.Val()
|
||||
}
|
86
database/gdb/gdb_core_utility.go
Normal file
86
database/gdb/gdb_core_utility.go
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// GetMaster acts like function Master but with additional <schema> parameter specifying
|
||||
// the schema for the connection. It is defined for internal usage.
|
||||
// Also see Master.
|
||||
func (c *Core) GetMaster(schema ...string) (*sql.DB, error) {
|
||||
return c.getSqlDb(true, schema...)
|
||||
}
|
||||
|
||||
// GetSlave acts like function Slave but with additional <schema> parameter specifying
|
||||
// the schema for the connection. It is defined for internal usage.
|
||||
// Also see Slave.
|
||||
func (c *Core) GetSlave(schema ...string) (*sql.DB, error) {
|
||||
return c.getSqlDb(false, schema...)
|
||||
}
|
||||
|
||||
// QuoteWord checks given string <s> a word, if true quotes it with security chars of the database
|
||||
// and returns the quoted string; or else return <s> without any change.
|
||||
func (c *Core) QuoteWord(s string) string {
|
||||
charLeft, charRight := c.DB.GetChars()
|
||||
return doQuoteWord(s, charLeft, charRight)
|
||||
}
|
||||
|
||||
// QuoteString quotes string with quote chars. Strings like:
|
||||
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
|
||||
func (c *Core) QuoteString(s string) string {
|
||||
charLeft, charRight := c.DB.GetChars()
|
||||
return doQuoteString(s, charLeft, charRight)
|
||||
}
|
||||
|
||||
// QuotePrefixTableName adds prefix string and quotes chars for the table.
|
||||
// It handles table string like:
|
||||
// "user", "user u",
|
||||
// "user,user_detail",
|
||||
// "user u, user_detail ut",
|
||||
// "user as u, user_detail as ut".
|
||||
//
|
||||
// Note that, this will automatically checks the table prefix whether already added,
|
||||
// if true it does nothing to the table name, or else adds the prefix to the table name.
|
||||
func (c *Core) QuotePrefixTableName(table string) string {
|
||||
charLeft, charRight := c.DB.GetChars()
|
||||
return doHandleTableName(table, c.DB.GetPrefix(), charLeft, charRight)
|
||||
}
|
||||
|
||||
// GetChars returns the security char for current database.
|
||||
// It does nothing in default.
|
||||
func (c *Core) GetChars() (charLeft string, charRight string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// HandleSqlBeforeCommit handles the sql before posts it to database.
|
||||
// It does nothing in default.
|
||||
func (c *Core) HandleSqlBeforeCommit(sql string) string {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
//
|
||||
// It does nothing in default.
|
||||
func (c *Core) Tables(schema ...string) (tables []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
//
|
||||
// Note that it returns a map containing the field name and its corresponding fields.
|
||||
// As a map is unsorted, the TableField struct has a "Index" field marks its sequence in the fields.
|
||||
//
|
||||
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
|
||||
//
|
||||
// It does nothing in default.
|
||||
func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
return
|
||||
}
|
@ -5,7 +5,7 @@
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
//
|
||||
// Note:
|
||||
// 1. It needs manually import: _ "github.com/lib/pq"
|
||||
// 1. It needs manually import: _ "github.com/denisenkom/go-mssqldb"
|
||||
// 2. It does not support Save/Replace features.
|
||||
// 3. It does not support LastInsertId.
|
||||
|
||||
@ -22,12 +22,21 @@ import (
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
)
|
||||
|
||||
type dbMssql struct {
|
||||
*dbBase
|
||||
// DriverMssql is the driver for SQL server database.
|
||||
type DriverMssql struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
// New creates and returns a database object for SQL server.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
return &DriverMssql{
|
||||
Core: core,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for mssql.
|
||||
func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
source := ""
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
@ -45,13 +54,13 @@ func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (db *dbMssql) getChars() (charLeft string, charRight string) {
|
||||
// GetChars returns the security char for this type of database.
|
||||
func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
|
||||
return "\"", "\""
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
|
||||
func (db *dbMssql) handleSqlBeforeExec(query string) string {
|
||||
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
|
||||
var index int
|
||||
// Convert place holder char '?' to string "@px".
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
|
||||
@ -59,10 +68,10 @@ func (db *dbMssql) handleSqlBeforeExec(query string) string {
|
||||
return fmt.Sprintf("@p%d", index)
|
||||
})
|
||||
str, _ = gregex.ReplaceString("\"", "", str)
|
||||
return db.parseSql(str)
|
||||
return d.parseSql(str), args
|
||||
}
|
||||
|
||||
func (db *dbMssql) parseSql(sql string) string {
|
||||
func (d *DriverMssql) parseSql(sql string) string {
|
||||
// SELECT * FROM USER WHERE ID=1 LIMIT 1
|
||||
if m, _ := gregex.MatchString(`^SELECT(.+)LIMIT 1$`, sql); len(m) > 1 {
|
||||
return fmt.Sprintf(`SELECT TOP 1 %s`, m[1])
|
||||
@ -163,44 +172,45 @@ func (db *dbMssql) parseSql(sql string) string {
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
func (db *dbMssql) Tables(schema ...string) (tables []string, err error) {
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) {
|
||||
var result Result
|
||||
link, err := db.getSlave(schema...)
|
||||
link, err := d.DB.GetSlave(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err = db.doGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
|
||||
result, err = d.DB.DoGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, m := range result {
|
||||
for _, v := range m {
|
||||
tables = append(tables, strings.ToLower(v.String()))
|
||||
tables = append(tables, v.String())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
func (db *dbMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
func (d *DriverMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function TableFields supports only single table operations")
|
||||
}
|
||||
checkSchema := db.schema.Val()
|
||||
checkSchema := d.DB.GetSchema()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
checkSchema = schema[0]
|
||||
}
|
||||
v := db.cache.GetOrSetFunc(
|
||||
v := d.DB.GetCache().GetOrSetFunc(
|
||||
fmt.Sprintf(`mssql_table_fields_%s_%s`, table, checkSchema), func() interface{} {
|
||||
var result Result
|
||||
var link *sql.DB
|
||||
link, err = db.getSlave(checkSchema)
|
||||
link, err = d.DB.GetSlave(checkSchema)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result, err = db.doGetAll(link, fmt.Sprintf(`
|
||||
result, err = d.DB.DoGetAll(link, fmt.Sprintf(`
|
||||
SELECT c.name as FIELD, CASE t.name
|
||||
WHEN 'numeric' THEN t.name + '(' + convert(varchar(20),c.xprec) + ',' + convert(varchar(20),c.xscale) + ')'
|
||||
WHEN 'char' THEN t.name + '(' + convert(varchar(20),c.length)+ ')'
|
@ -12,15 +12,24 @@ import (
|
||||
"github.com/gogf/gf/internal/intlog"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
|
||||
_ "github.com/gf-third/mysql"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type dbMysql struct {
|
||||
*dbBase
|
||||
// DriverMysql is the driver for mysql database.
|
||||
type DriverMysql struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
// New creates and returns a database object for mysql.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
return &DriverMysql{
|
||||
Core: core,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for mysql.
|
||||
func (db *dbMysql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
func (d *DriverMysql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
@ -31,31 +40,32 @@ func (db *dbMysql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
)
|
||||
}
|
||||
intlog.Printf("Open: %s", source)
|
||||
if db, err := sql.Open("gf-mysql", source); err == nil {
|
||||
if db, err := sql.Open("mysql", source); err == nil {
|
||||
return db, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (db *dbMysql) getChars() (charLeft string, charRight string) {
|
||||
// GetChars returns the security char for this type of database.
|
||||
func (d *DriverMysql) GetChars() (charLeft string, charRight string) {
|
||||
return "`", "`"
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec handles the sql before posts it to database.
|
||||
func (db *dbMysql) handleSqlBeforeExec(sql string) string {
|
||||
return sql
|
||||
// HandleSqlBeforeCommit handles the sql before posts it to database.
|
||||
func (d *DriverMysql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
return sql, args
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
func (bs *dbBase) Tables(schema ...string) (tables []string, err error) {
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) {
|
||||
var result Result
|
||||
link, err := bs.db.getSlave(schema...)
|
||||
link, err := d.DB.GetSlave(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err = bs.db.doGetAll(link, `SHOW TABLES`)
|
||||
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -73,27 +83,27 @@ func (bs *dbBase) Tables(schema ...string) (tables []string, err error) {
|
||||
// As a map is unsorted, the TableField struct has a "Index" field marks its sequence in the fields.
|
||||
//
|
||||
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
|
||||
func (bs *dbBase) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
func (d *DriverMysql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function TableFields supports only single table operations")
|
||||
}
|
||||
checkSchema := bs.schema.Val()
|
||||
checkSchema := d.schema.Val()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
checkSchema = schema[0]
|
||||
}
|
||||
v := bs.cache.GetOrSetFunc(
|
||||
v := d.cache.GetOrSetFunc(
|
||||
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
|
||||
func() interface{} {
|
||||
var result Result
|
||||
var link *sql.DB
|
||||
link, err = bs.db.getSlave(checkSchema)
|
||||
link, err = d.DB.GetSlave(checkSchema)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result, err = bs.doGetAll(
|
||||
result, err = d.DB.DoGetAll(
|
||||
link,
|
||||
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, bs.db.quoteWord(table)),
|
||||
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil
|
@ -24,8 +24,9 @@ import (
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
)
|
||||
|
||||
type dbOracle struct {
|
||||
*dbBase
|
||||
// DriverOracle is the driver for oracle database.
|
||||
type DriverOracle struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
const (
|
||||
@ -33,8 +34,16 @@ const (
|
||||
tableAlias2 = "GFORM2"
|
||||
)
|
||||
|
||||
// New creates and returns a database object for oracle.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
return &DriverOracle{
|
||||
Core: core,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for oracle.
|
||||
func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
func (d *DriverOracle) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
@ -49,13 +58,13 @@ func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (db *dbOracle) getChars() (charLeft string, charRight string) {
|
||||
// GetChars returns the security char for this type of database.
|
||||
func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
|
||||
return "\"", "\""
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
|
||||
func (db *dbOracle) handleSqlBeforeExec(query string) string {
|
||||
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
|
||||
var index int
|
||||
// Convert place holder char '?' to string ":x".
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
|
||||
@ -63,10 +72,10 @@ func (db *dbOracle) handleSqlBeforeExec(query string) string {
|
||||
return fmt.Sprintf(":%d", index)
|
||||
})
|
||||
str, _ = gregex.ReplaceString("\"", "", str)
|
||||
return db.parseSql(str)
|
||||
return d.parseSql(str), args
|
||||
}
|
||||
|
||||
func (db *dbOracle) parseSql(sql string) string {
|
||||
func (d *DriverOracle) parseSql(sql string) string {
|
||||
patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
|
||||
if gregex.IsMatchString(patten, sql) == false {
|
||||
return sql
|
||||
@ -123,36 +132,37 @@ func (db *dbOracle) parseSql(sql string) string {
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
func (db *dbOracle) Tables(schema ...string) (tables []string, err error) {
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
// Note that it ignores the parameter <schema> in oracle database, as it is not necessary.
|
||||
func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) {
|
||||
var result Result
|
||||
|
||||
result, err = db.doGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME")
|
||||
result, err = d.DB.DoGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, m := range result {
|
||||
for _, v := range m {
|
||||
tables = append(tables, strings.ToLower(v.String()))
|
||||
tables = append(tables, v.String())
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
func (db *dbOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
func (d *DriverOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function TableFields supports only single table operations")
|
||||
}
|
||||
checkSchema := db.schema.Val()
|
||||
checkSchema := d.DB.GetSchema()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
checkSchema = schema[0]
|
||||
}
|
||||
v := db.cache.GetOrSetFunc(
|
||||
v := d.DB.GetCache().GetOrSetFunc(
|
||||
fmt.Sprintf(`oracle_table_fields_%s_%s`, table, checkSchema),
|
||||
func() interface{} {
|
||||
result := (Result)(nil)
|
||||
result, err = db.GetAll(fmt.Sprintf(`
|
||||
result, err = d.DB.GetAll(fmt.Sprintf(`
|
||||
SELECT COLUMN_NAME AS FIELD, CASE DATA_TYPE
|
||||
WHEN 'NUMBER' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')'
|
||||
WHEN 'FLOAT' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')'
|
||||
@ -177,11 +187,11 @@ func (db *dbOracle) TableFields(table string, schema ...string) (fields map[stri
|
||||
return
|
||||
}
|
||||
|
||||
func (db *dbOracle) getTableUniqueIndex(table string) (fields map[string]map[string]string, err error) {
|
||||
func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[string]string, err error) {
|
||||
table = strings.ToUpper(table)
|
||||
v := db.cache.GetOrSetFunc("table_unique_index_"+table, func() interface{} {
|
||||
v := d.DB.GetCache().GetOrSetFunc("table_unique_index_"+table, func() interface{} {
|
||||
res := (Result)(nil)
|
||||
res, err = db.GetAll(fmt.Sprintf(`
|
||||
res, err = d.DB.GetAll(fmt.Sprintf(`
|
||||
SELECT INDEX_NAME,COLUMN_NAME,CHAR_LENGTH FROM USER_IND_COLUMNS
|
||||
WHERE TABLE_NAME = '%s'
|
||||
AND INDEX_NAME IN(SELECT INDEX_NAME FROM USER_INDEXES WHERE TABLE_NAME='%s' AND UNIQUENESS='UNIQUE')
|
||||
@ -203,7 +213,7 @@ func (db *dbOracle) getTableUniqueIndex(table string) (fields map[string]map[str
|
||||
return
|
||||
}
|
||||
|
||||
func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
var fields []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
@ -218,11 +228,11 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
case reflect.Slice:
|
||||
fallthrough
|
||||
case reflect.Array:
|
||||
return db.db.doBatchInsert(link, table, data, option, batch...)
|
||||
return d.DB.DoBatchInsert(link, table, data, option, batch...)
|
||||
case reflect.Map:
|
||||
fallthrough
|
||||
case reflect.Struct:
|
||||
dataMap = varToMapDeep(data)
|
||||
dataMap = DataToMapDeep(data)
|
||||
default:
|
||||
return result, errors.New(fmt.Sprint("unsupported data type:", kind))
|
||||
}
|
||||
@ -231,7 +241,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
indexMap := make(map[string]string)
|
||||
indexExists := false
|
||||
if option != gINSERT_OPTION_DEFAULT {
|
||||
index, err := db.getTableUniqueIndex(table)
|
||||
index, err := d.getTableUniqueIndex(table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -253,7 +263,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
onStr := make([]string, 0)
|
||||
updateStr := make([]string, 0)
|
||||
|
||||
charL, charR := db.db.getChars()
|
||||
charL, charR := d.DB.GetChars()
|
||||
for k, v := range dataMap {
|
||||
k = strings.ToUpper(k)
|
||||
|
||||
@ -279,7 +289,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
}
|
||||
|
||||
if link == nil {
|
||||
if link, err = db.db.Master(); err != nil {
|
||||
if link, err = d.DB.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -294,9 +304,9 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
table, tableAlias1, strings.Join(subSqlStr, ","), tableAlias2,
|
||||
strings.Join(onStr, "AND"), strings.Join(updateStr, ","), strings.Join(fields, ","), strings.Join(values, ","),
|
||||
)
|
||||
return db.db.doExec(link, tmp, params...)
|
||||
return d.DB.DoExec(link, tmp, params...)
|
||||
case gINSERT_OPTION_IGNORE:
|
||||
return db.db.doExec(link,
|
||||
return d.DB.DoExec(link,
|
||||
fmt.Sprintf(
|
||||
"INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)",
|
||||
table, strings.Join(indexs, ","), table, strings.Join(fields, ","), strings.Join(values, ","),
|
||||
@ -305,7 +315,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
}
|
||||
}
|
||||
|
||||
return db.db.doExec(
|
||||
return d.DB.DoExec(
|
||||
link,
|
||||
fmt.Sprintf(
|
||||
"INSERT INTO %s(%s) VALUES(%s)",
|
||||
@ -314,7 +324,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
|
||||
params...)
|
||||
}
|
||||
|
||||
func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
@ -342,12 +352,12 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
|
||||
case reflect.Array:
|
||||
listMap = make(List, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
listMap[i] = varToMapDeep(rv.Index(i).Interface())
|
||||
listMap[i] = DataToMapDeep(rv.Index(i).Interface())
|
||||
}
|
||||
case reflect.Map:
|
||||
fallthrough
|
||||
case reflect.Struct:
|
||||
listMap = List{Map(varToMapDeep(list))}
|
||||
listMap = List{Map(DataToMapDeep(list))}
|
||||
default:
|
||||
return result, errors.New(fmt.Sprint("unsupported list type:", kind))
|
||||
}
|
||||
@ -357,7 +367,7 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
|
||||
return result, errors.New("empty data list")
|
||||
}
|
||||
if link == nil {
|
||||
if link, err = db.db.Master(); err != nil {
|
||||
if link, err = d.DB.Master(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -368,14 +378,14 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
|
||||
holders = append(holders, "?")
|
||||
}
|
||||
batchResult := new(batchSqlResult)
|
||||
charL, charR := db.db.getChars()
|
||||
charL, charR := d.DB.GetChars()
|
||||
keyStr := charL + strings.Join(keys, charL+","+charR) + charR
|
||||
valueHolderStr := strings.Join(holders, ",")
|
||||
|
||||
// 当操作类型非insert时调用单笔的insert功能
|
||||
if option != gINSERT_OPTION_DEFAULT {
|
||||
for _, v := range listMap {
|
||||
r, err := db.doInsert(link, table, v, option, 1)
|
||||
r, err := d.DB.DoInsert(link, table, v, option, 1)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
@ -402,10 +412,9 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
|
||||
params = append(params, listMap[i][k])
|
||||
}
|
||||
values = append(values, valueHolderStr)
|
||||
|
||||
intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr))
|
||||
if len(intoStr) == batchNum {
|
||||
r, err := db.db.doExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
|
||||
r, err := d.DB.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
@ -421,7 +430,7 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
|
||||
}
|
||||
// 处理最后不构成指定批量的数据
|
||||
if len(intoStr) > 0 {
|
||||
r, err := db.db.doExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
|
||||
r, err := d.DB.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
@ -21,12 +21,21 @@ import (
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
)
|
||||
|
||||
type dbPgsql struct {
|
||||
*dbBase
|
||||
// DriverPgsql is the driver for postgresql database.
|
||||
type DriverPgsql struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
// New creates and returns a database object for postgresql.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
return &DriverPgsql{
|
||||
Core: core,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for pgsql.
|
||||
func (db *dbPgsql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
func (d *DriverPgsql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
@ -44,13 +53,13 @@ func (db *dbPgsql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (db *dbPgsql) getChars() (charLeft string, charRight string) {
|
||||
// GetChars returns the security char for this type of database.
|
||||
func (d *DriverPgsql) GetChars() (charLeft string, charRight string) {
|
||||
return "\"", "\""
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
|
||||
func (db *dbPgsql) handleSqlBeforeExec(sql string) string {
|
||||
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverPgsql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
var index int
|
||||
// Convert place holder char '?' to string "$x".
|
||||
sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
|
||||
@ -58,13 +67,14 @@ func (db *dbPgsql) handleSqlBeforeExec(sql string) string {
|
||||
return fmt.Sprintf("$%d", index)
|
||||
})
|
||||
sql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $1 OFFSET $2`, sql)
|
||||
return sql
|
||||
return sql, args
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) {
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) {
|
||||
var result Result
|
||||
link, err := db.getSlave(schema...)
|
||||
link, err := d.DB.GetSlave(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -73,7 +83,7 @@ func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) {
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
query = fmt.Sprintf("SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = '%s' ORDER BY TABLENAME", schema[0])
|
||||
}
|
||||
result, err = db.doGetAll(link, query)
|
||||
result, err = d.DB.DoGetAll(link, query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -86,25 +96,25 @@ func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) {
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
func (db *dbPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
func (d *DriverPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function TableFields supports only single table operations")
|
||||
}
|
||||
table, _ = gregex.ReplaceString("\"", "", table)
|
||||
checkSchema := db.schema.Val()
|
||||
checkSchema := d.DB.GetSchema()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
checkSchema = schema[0]
|
||||
}
|
||||
v := db.cache.GetOrSetFunc(
|
||||
v := d.DB.GetCache().GetOrSetFunc(
|
||||
fmt.Sprintf(`pgsql_table_fields_%s_%s`, table, checkSchema), func() interface{} {
|
||||
var result Result
|
||||
var link *sql.DB
|
||||
link, err = db.getSlave(checkSchema)
|
||||
link, err = d.DB.GetSlave(checkSchema)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result, err = db.doGetAll(link, fmt.Sprintf(`
|
||||
result, err = d.DB.DoGetAll(link, fmt.Sprintf(`
|
||||
SELECT a.attname AS field, t.typname AS type FROM pg_class c, pg_attribute a
|
||||
LEFT OUTER JOIN pg_description b ON a.attrelid=b.objoid AND a.attnum = b.objsubid,pg_type t
|
||||
WHERE c.relname = '%s' and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid
|
@ -16,12 +16,21 @@ import (
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
)
|
||||
|
||||
type dbSqlite struct {
|
||||
*dbBase
|
||||
// DriverSqlite is the driver for sqlite database.
|
||||
type DriverSqlite struct {
|
||||
*Core
|
||||
}
|
||||
|
||||
// New creates and returns a database object for sqlite.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *DriverSqlite) New(core *Core, node *ConfigNode) (DB, error) {
|
||||
return &DriverSqlite{
|
||||
Core: core,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for sqlite.
|
||||
func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
func (d *DriverSqlite) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
@ -36,30 +45,31 @@ func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (db *dbSqlite) getChars() (charLeft string, charRight string) {
|
||||
// GetChars returns the security char for this type of database.
|
||||
func (d *DriverSqlite) GetChars() (charLeft string, charRight string) {
|
||||
return "`", "`"
|
||||
}
|
||||
|
||||
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
|
||||
// @todo 需要增加对Save方法的支持,可使用正则来实现替换,
|
||||
// @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
|
||||
func (d *DriverSqlite) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
return sql, args
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
// It's mainly used in cli tool chain for automatically generating the models.
|
||||
// TODO
|
||||
func (db *dbSqlite) Tables(schema ...string) (tables []string, err error) {
|
||||
func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
// TODO
|
||||
func (db *dbSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
func (d *DriverSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function TableFields supports only single table operations")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
|
||||
// @todo 需要增加对Save方法的支持,可使用正则来实现替换,
|
||||
// @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
|
||||
func (db *dbSqlite) handleSqlBeforeExec(sql string) string {
|
||||
return sql
|
||||
}
|
@ -51,7 +51,54 @@ var (
|
||||
quoteWordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
|
||||
)
|
||||
|
||||
// handleTableName adds prefix string and quote chars for the table. It handles table string like:
|
||||
// GetInsertOperationByOption returns proper insert option with given parameter <option>.
|
||||
func GetInsertOperationByOption(option int) string {
|
||||
var operator string
|
||||
switch option {
|
||||
case gINSERT_OPTION_REPLACE:
|
||||
operator = "REPLACE"
|
||||
case gINSERT_OPTION_IGNORE:
|
||||
operator = "INSERT IGNORE"
|
||||
default:
|
||||
operator = "INSERT"
|
||||
}
|
||||
return operator
|
||||
}
|
||||
|
||||
// DataToMapDeep converts struct object to map type recursively.
|
||||
func DataToMapDeep(obj interface{}) map[string]interface{} {
|
||||
data := gconv.Map(obj, ORM_TAG_FOR_STRUCT)
|
||||
for key, value := range data {
|
||||
rv := reflect.ValueOf(value)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
// The underlying driver supports time.Time/*time.Time types.
|
||||
if _, ok := value.(time.Time); ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := value.(*time.Time); ok {
|
||||
continue
|
||||
}
|
||||
// Use string conversion in default.
|
||||
if s, ok := value.(apiString); ok {
|
||||
data[key] = s.String()
|
||||
continue
|
||||
}
|
||||
delete(data, key)
|
||||
for k, v := range DataToMapDeep(value) {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// QuotePrefixTableName adds prefix string and quote chars for the table. It handles table string like:
|
||||
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut", "user.user u".
|
||||
//
|
||||
// Note that, this will automatically checks the table prefix whether already added, if true it does
|
||||
@ -196,7 +243,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
|
||||
newArgs = formatWhereInterfaces(db, gconv.Interfaces(where), buffer, newArgs)
|
||||
|
||||
case reflect.Map:
|
||||
for key, value := range varToMapDeep(where) {
|
||||
for key, value := range DataToMapDeep(where) {
|
||||
if omitEmpty && empty.IsEmpty(value) {
|
||||
continue
|
||||
}
|
||||
@ -218,7 +265,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
|
||||
})
|
||||
break
|
||||
}
|
||||
for key, value := range varToMapDeep(where) {
|
||||
for key, value := range DataToMapDeep(where) {
|
||||
if omitEmpty && empty.IsEmpty(value) {
|
||||
continue
|
||||
}
|
||||
@ -269,7 +316,7 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new
|
||||
|
||||
// formatWhereKeyValue handles each key-value pair of the parameter map.
|
||||
func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} {
|
||||
key = db.quoteWord(key)
|
||||
key = db.QuoteWord(key)
|
||||
if buffer.Len() > 0 {
|
||||
buffer.WriteString(" AND ")
|
||||
}
|
||||
@ -314,39 +361,6 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
|
||||
return newArgs
|
||||
}
|
||||
|
||||
// varToMapDeep converts struct object to map type recursively.
|
||||
func varToMapDeep(obj interface{}) map[string]interface{} {
|
||||
data := gconv.Map(obj, ORM_TAG_FOR_STRUCT)
|
||||
for key, value := range data {
|
||||
rv := reflect.ValueOf(value)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
// The underlying driver supports time.Time/*time.Time types.
|
||||
if _, ok := value.(time.Time); ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := value.(*time.Time); ok {
|
||||
continue
|
||||
}
|
||||
// Use string conversion in default.
|
||||
if s, ok := value.(apiString); ok {
|
||||
data[key] = s.String()
|
||||
continue
|
||||
}
|
||||
delete(data, key)
|
||||
for k, v := range varToMapDeep(value) {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// handleArguments is a nice function which handles the query and its arguments before committing to
|
||||
// underlying driver.
|
||||
func handleArguments(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
|
||||
@ -422,20 +436,6 @@ func formatError(err error, query string, args ...interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// getInsertOperationByOption returns proper insert option with given parameter <option>.
|
||||
func getInsertOperationByOption(option int) string {
|
||||
var operator string
|
||||
switch option {
|
||||
case gINSERT_OPTION_REPLACE:
|
||||
operator = "REPLACE"
|
||||
case gINSERT_OPTION_IGNORE:
|
||||
operator = "INSERT IGNORE"
|
||||
default:
|
||||
operator = "INSERT"
|
||||
}
|
||||
return operator
|
||||
}
|
||||
|
||||
// bindArgsToQuery binds the arguments to the query string and returns a complete
|
||||
// sql string, just for debugging.
|
||||
func bindArgsToQuery(query string, args []interface{}) string {
|
||||
|
@ -68,10 +68,10 @@ const (
|
||||
// Table creates and returns a new ORM model from given schema.
|
||||
// The parameter <tables> can be more than one table names, like :
|
||||
// "user", "user u", "user, user_detail", "user u, user_detail ud"
|
||||
func (bs *dbBase) Table(table string) *Model {
|
||||
table = bs.db.handleTableName(table)
|
||||
func (c *Core) Table(table string) *Model {
|
||||
table = c.DB.QuotePrefixTableName(table)
|
||||
return &Model{
|
||||
db: bs.db,
|
||||
db: c.DB,
|
||||
tablesInit: table,
|
||||
tables: table,
|
||||
fields: "*",
|
||||
@ -82,23 +82,23 @@ func (bs *dbBase) Table(table string) *Model {
|
||||
}
|
||||
}
|
||||
|
||||
// Model is alias of dbBase.Table.
|
||||
// See dbBase.Table.
|
||||
func (bs *dbBase) Model(table string) *Model {
|
||||
return bs.db.Table(table)
|
||||
// Model is alias of Core.Table.
|
||||
// See Core.Table.
|
||||
func (c *Core) Model(table string) *Model {
|
||||
return c.DB.Table(table)
|
||||
}
|
||||
|
||||
// From is alias of dbBase.Table.
|
||||
// See dbBase.Table.
|
||||
// From is alias of Core.Table.
|
||||
// See Core.Table.
|
||||
// Deprecated.
|
||||
func (bs *dbBase) From(table string) *Model {
|
||||
return bs.db.Table(table)
|
||||
func (c *Core) From(table string) *Model {
|
||||
return c.DB.Table(table)
|
||||
}
|
||||
|
||||
// Table acts like dbBase.Table except it operates on transaction.
|
||||
// See dbBase.Table.
|
||||
// Table acts like Core.Table except it operates on transaction.
|
||||
// See Core.Table.
|
||||
func (tx *TX) Table(table string) *Model {
|
||||
table = tx.db.handleTableName(table)
|
||||
table = tx.db.QuotePrefixTableName(table)
|
||||
return &Model{
|
||||
db: tx.db,
|
||||
tx: tx,
|
||||
@ -217,21 +217,21 @@ func (m *Model) getModel() *Model {
|
||||
// LeftJoin does "LEFT JOIN ... ON ..." statement on the model.
|
||||
func (m *Model) LeftJoin(table string, on string) *Model {
|
||||
model := m.getModel()
|
||||
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.handleTableName(table), on)
|
||||
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on)
|
||||
return model
|
||||
}
|
||||
|
||||
// RightJoin does "RIGHT JOIN ... ON ..." statement on the model.
|
||||
func (m *Model) RightJoin(table string, on string) *Model {
|
||||
model := m.getModel()
|
||||
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.handleTableName(table), on)
|
||||
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on)
|
||||
return model
|
||||
}
|
||||
|
||||
// InnerJoin does "INNER JOIN ... ON ..." statement on the model.
|
||||
func (m *Model) InnerJoin(table string, on string) *Model {
|
||||
model := m.getModel()
|
||||
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.handleTableName(table), on)
|
||||
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on)
|
||||
return model
|
||||
}
|
||||
|
||||
@ -403,7 +403,7 @@ func (m *Model) Or(where interface{}, args ...interface{}) *Model {
|
||||
// Group sets the "GROUP BY" statement for the model.
|
||||
func (m *Model) Group(groupBy string) *Model {
|
||||
model := m.getModel()
|
||||
model.groupBy = m.db.quoteString(groupBy)
|
||||
model.groupBy = m.db.QuoteString(groupBy)
|
||||
return model
|
||||
}
|
||||
|
||||
@ -417,7 +417,7 @@ func (m *Model) GroupBy(groupBy string) *Model {
|
||||
// Order sets the "ORDER BY" statement for the model.
|
||||
func (m *Model) Order(orderBy string) *Model {
|
||||
model := m.getModel()
|
||||
model.orderBy = m.db.quoteString(orderBy)
|
||||
model.orderBy = m.db.QuoteString(orderBy)
|
||||
return model
|
||||
}
|
||||
|
||||
@ -540,11 +540,11 @@ func (m *Model) Data(data ...interface{}) *Model {
|
||||
case reflect.Slice, reflect.Array:
|
||||
list := make(List, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
list[i] = varToMapDeep(rv.Index(i).Interface())
|
||||
list[i] = DataToMapDeep(rv.Index(i).Interface())
|
||||
}
|
||||
model.data = list
|
||||
case reflect.Map, reflect.Struct:
|
||||
model.data = varToMapDeep(data[0])
|
||||
model.data = DataToMapDeep(data[0])
|
||||
default:
|
||||
model.data = data[0]
|
||||
}
|
||||
@ -586,7 +586,7 @@ func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql.
|
||||
if m.batch > 0 {
|
||||
batch = m.batch
|
||||
}
|
||||
return m.db.doBatchInsert(
|
||||
return m.db.DoBatchInsert(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(list),
|
||||
@ -595,7 +595,7 @@ func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql.
|
||||
)
|
||||
} else if data, ok := m.data.(Map); ok {
|
||||
// Single insert.
|
||||
return m.db.doInsert(
|
||||
return m.db.DoInsert(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(data),
|
||||
@ -626,7 +626,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
|
||||
if m.batch > 0 {
|
||||
batch = m.batch
|
||||
}
|
||||
return m.db.doBatchInsert(
|
||||
return m.db.DoBatchInsert(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(list),
|
||||
@ -635,7 +635,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
|
||||
)
|
||||
} else if data, ok := m.data.(Map); ok {
|
||||
// Single insert.
|
||||
return m.db.doInsert(
|
||||
return m.db.DoInsert(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(data),
|
||||
@ -669,7 +669,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
|
||||
if m.batch > 0 {
|
||||
batch = m.batch
|
||||
}
|
||||
return m.db.doBatchInsert(
|
||||
return m.db.DoBatchInsert(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(list),
|
||||
@ -678,7 +678,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
|
||||
)
|
||||
} else if data, ok := m.data.(Map); ok {
|
||||
// Single save.
|
||||
return m.db.doInsert(
|
||||
return m.db.DoInsert(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(data),
|
||||
@ -712,7 +712,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
|
||||
return nil, errors.New("updating table with empty data")
|
||||
}
|
||||
condition, conditionArgs := m.formatCondition(false)
|
||||
return m.db.doUpdate(
|
||||
return m.db.DoUpdate(
|
||||
m.getLink(true),
|
||||
m.tables,
|
||||
m.filterDataForInsertOrUpdate(m.data),
|
||||
@ -734,7 +734,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
|
||||
}
|
||||
}()
|
||||
condition, conditionArgs := m.formatCondition(false)
|
||||
return m.db.doDelete(m.getLink(true), m.tables, condition, conditionArgs...)
|
||||
return m.db.DoDelete(m.getLink(true), m.tables, condition, conditionArgs...)
|
||||
}
|
||||
|
||||
// Select is alias of Model.All.
|
||||
@ -1045,7 +1045,7 @@ func (m *Model) doFilterDataMapForInsertOrUpdate(data Map, allowOmitEmpty bool)
|
||||
|
||||
// getLink returns the underlying database link object with configured <linkType> attribute.
|
||||
// The parameter <master> specifies whether using the master node if master-slave configured.
|
||||
func (m *Model) getLink(master bool) dbLink {
|
||||
func (m *Model) getLink(master bool) Link {
|
||||
if m.tx != nil {
|
||||
return m.tx.tx
|
||||
}
|
||||
@ -1059,10 +1059,16 @@ func (m *Model) getLink(master bool) dbLink {
|
||||
}
|
||||
switch linkType {
|
||||
case gLINK_TYPE_MASTER:
|
||||
link, _ := m.db.getMaster(m.schema)
|
||||
link, err := m.db.GetMaster(m.schema)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return link
|
||||
case gLINK_TYPE_SLAVE:
|
||||
link, _ := m.db.getSlave(m.schema)
|
||||
link, err := m.db.GetSlave(m.schema)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return link
|
||||
}
|
||||
return nil
|
||||
@ -1077,17 +1083,17 @@ func (m *Model) getAll(query string, args ...interface{}) (result Result, err er
|
||||
if len(cacheKey) == 0 {
|
||||
cacheKey = query + "/" + gconv.String(args)
|
||||
}
|
||||
if v := m.db.getCache().Get(cacheKey); v != nil {
|
||||
if v := m.db.GetCache().Get(cacheKey); v != nil {
|
||||
return v.(Result), nil
|
||||
}
|
||||
}
|
||||
result, err = m.db.doGetAll(m.getLink(false), query, args...)
|
||||
result, err = m.db.DoGetAll(m.getLink(false), query, args...)
|
||||
// Cache the result.
|
||||
if len(cacheKey) > 0 && err == nil {
|
||||
if m.cacheDuration < 0 {
|
||||
m.db.getCache().Remove(cacheKey)
|
||||
m.db.GetCache().Remove(cacheKey)
|
||||
} else {
|
||||
m.db.getCache().Set(cacheKey, result, m.cacheDuration)
|
||||
m.db.GetCache().Set(cacheKey, result, m.cacheDuration)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -1113,7 +1119,7 @@ func (m *Model) getPrimaryKey() string {
|
||||
// checkAndRemoveCache checks and remove the cache if necessary.
|
||||
func (m *Model) checkAndRemoveCache() {
|
||||
if m.cacheEnabled && m.cacheDuration < 0 && len(m.cacheName) > 0 {
|
||||
m.db.getCache().Remove(m.cacheName)
|
||||
m.db.GetCache().Remove(m.cacheName)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -14,9 +14,9 @@ type Schema struct {
|
||||
}
|
||||
|
||||
// Schema creates and returns a schema.
|
||||
func (bs *dbBase) Schema(schema string) *Schema {
|
||||
func (c *Core) Schema(schema string) *Schema {
|
||||
return &Schema{
|
||||
db: bs.db,
|
||||
db: c.DB,
|
||||
schema: schema,
|
||||
}
|
||||
}
|
||||
@ -44,8 +44,8 @@ func (s *Schema) Table(table string) *Model {
|
||||
return m
|
||||
}
|
||||
|
||||
// Model is alias of dbBase.Table.
|
||||
// See dbBase.Table.
|
||||
// Model is alias of Core.Table.
|
||||
// See Core.Table.
|
||||
func (s *Schema) Model(table string) *Model {
|
||||
return s.Table(table)
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ import (
|
||||
|
||||
// convertValue automatically checks and converts field value from database type
|
||||
// to golang variable type.
|
||||
func (bs *dbBase) convertValue(fieldValue []byte, fieldType string) interface{} {
|
||||
func (c *Core) convertValue(fieldValue []byte, fieldType string) interface{} {
|
||||
t, _ := gregex.ReplaceString(`\(.+\)`, "", fieldType)
|
||||
t = strings.ToLower(t)
|
||||
switch t {
|
||||
@ -106,10 +106,10 @@ func (bs *dbBase) convertValue(fieldValue []byte, fieldType string) interface{}
|
||||
}
|
||||
|
||||
// filterFields removes all key-value pairs which are not the field of given table.
|
||||
func (bs *dbBase) filterFields(schema, table string, data map[string]interface{}) map[string]interface{} {
|
||||
func (c *Core) filterFields(schema, table string, data map[string]interface{}) map[string]interface{} {
|
||||
// It must use data copy here to avoid its changing the origin data map.
|
||||
newDataMap := make(map[string]interface{}, len(data))
|
||||
if fields, err := bs.db.TableFields(table, schema); err == nil {
|
||||
if fields, err := c.DB.TableFields(table, schema); err == nil {
|
||||
for k, v := range data {
|
||||
if _, ok := fields[k]; ok {
|
||||
newDataMap[k] = v
|
||||
|
@ -32,15 +32,15 @@ func (tx *TX) Rollback() error {
|
||||
}
|
||||
|
||||
// Query does query operation on transaction.
|
||||
// See dbBase.Query.
|
||||
// See Core.Query.
|
||||
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
return tx.db.doQuery(tx.tx, query, args...)
|
||||
return tx.db.DoQuery(tx.tx, query, args...)
|
||||
}
|
||||
|
||||
// Exec does none query operation on transaction.
|
||||
// See dbBase.Exec.
|
||||
// See Core.Exec.
|
||||
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.db.doExec(tx.tx, query, args...)
|
||||
return tx.db.DoExec(tx.tx, query, args...)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
@ -49,7 +49,7 @@ func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
// The caller must call the statement's Close method
|
||||
// when the statement is no longer needed.
|
||||
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
|
||||
return tx.db.doPrepare(tx.tx, query)
|
||||
return tx.db.DoPrepare(tx.tx, query)
|
||||
}
|
||||
|
||||
// GetAll queries and returns data records from database.
|
||||
@ -154,7 +154,7 @@ func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
|
||||
//
|
||||
// The parameter <batch> specifies the batch operation count when given data is slice.
|
||||
func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_DEFAULT, batch...)
|
||||
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_DEFAULT, batch...)
|
||||
}
|
||||
|
||||
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
|
||||
@ -167,7 +167,7 @@ func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result,
|
||||
//
|
||||
// The parameter <batch> specifies the batch operation count when given data is slice.
|
||||
func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_IGNORE, batch...)
|
||||
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_IGNORE, batch...)
|
||||
}
|
||||
|
||||
// Replace does "REPLACE INTO ..." statement for the table.
|
||||
@ -183,7 +183,7 @@ func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Re
|
||||
// If given data is type of slice, it then does batch replacing, and the optional parameter
|
||||
// <batch> specifies the batch operation count.
|
||||
func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_REPLACE, batch...)
|
||||
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_REPLACE, batch...)
|
||||
}
|
||||
|
||||
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
|
||||
@ -198,31 +198,31 @@ func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result,
|
||||
// If given data is type of slice, it then does batch saving, and the optional parameter
|
||||
// <batch> specifies the batch operation count.
|
||||
func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_SAVE, batch...)
|
||||
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_SAVE, batch...)
|
||||
}
|
||||
|
||||
// BatchInsert batch inserts data.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_DEFAULT, batch...)
|
||||
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_DEFAULT, batch...)
|
||||
}
|
||||
|
||||
// BatchInsert batch inserts data with ignore option.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_IGNORE, batch...)
|
||||
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_IGNORE, batch...)
|
||||
}
|
||||
|
||||
// BatchReplace batch replaces data.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (tx *TX) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_REPLACE, batch...)
|
||||
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_REPLACE, batch...)
|
||||
}
|
||||
|
||||
// BatchSave batch replaces data.
|
||||
// The parameter <list> must be type of slice of map or struct.
|
||||
func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_SAVE, batch...)
|
||||
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_SAVE, batch...)
|
||||
}
|
||||
|
||||
// Update does "UPDATE ... " statement for the table.
|
||||
@ -244,7 +244,7 @@ func (tx *TX) Update(table string, data interface{}, condition interface{}, args
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return tx.db.doUpdate(tx.tx, table, data, newWhere, newArgs...)
|
||||
return tx.db.DoUpdate(tx.tx, table, data, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
// Delete does "DELETE FROM ... " statement for the table.
|
||||
@ -263,5 +263,5 @@ func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return tx.db.doDelete(tx.tx, table, newWhere, newArgs...)
|
||||
return tx.db.DoDelete(tx.tx, table, newWhere, newArgs...)
|
||||
}
|
||||
|
74
database/gdb/gdb_unit_z_driver_test.go
Normal file
74
database/gdb/gdb_unit_z_driver_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/container/gtype"
|
||||
"github.com/gogf/gf/database/gdb"
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MyDriver is a custom database driver, which is used for testing only.
|
||||
// For simplifying the unit testing case purpose, MyDriver struct inherits the mysql driver
|
||||
// gdb.DriverMysql and overwrites its function HandleSqlBeforeCommit.
|
||||
// So if there's any sql execution, it goes through MyDriver.HandleSqlBeforeCommit firstly and
|
||||
// then gdb.DriverMysql.HandleSqlBeforeCommit.
|
||||
// You can call it sql "HOOK" or "HiJack" as your will.
|
||||
type MyDriver struct {
|
||||
*gdb.DriverMysql
|
||||
}
|
||||
|
||||
var (
|
||||
customDriverName = "MyDriver"
|
||||
latestSqlString = gtype.NewString() // For simplifying unit testing only.
|
||||
)
|
||||
|
||||
// New creates and returns a database object for mysql.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
|
||||
return &MyDriver{
|
||||
&gdb.DriverMysql{
|
||||
Core: core,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HandleSqlBeforeCommit handles the sql before posts it to database.
|
||||
// It here overwrites the same method of gdb.DriverMysql and makes some custom changes.
|
||||
func (d *MyDriver) HandleSqlBeforeCommit(link gdb.Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
latestSqlString.Set(sql)
|
||||
return d.DriverMysql.HandleSqlBeforeCommit(link, sql, args)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// It here registers my custom driver in package initialization function "init".
|
||||
// You can later use this type in the database configuration.
|
||||
gdb.Register(customDriverName, &MyDriver{})
|
||||
}
|
||||
|
||||
func Test_Custom_Driver(t *testing.T) {
|
||||
gdb.AddConfigNode("driver-test", gdb.ConfigNode{
|
||||
Host: "127.0.0.1",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Pass: "12345678",
|
||||
Name: "test",
|
||||
Type: customDriverName,
|
||||
Role: "master",
|
||||
Charset: "utf8",
|
||||
})
|
||||
gtest.Case(t, func() {
|
||||
gtest.Assert(latestSqlString.Val(), "")
|
||||
sqlString := "select 10000"
|
||||
value, err := g.DB("driver-test").GetValue(sqlString)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(value, 10000)
|
||||
gtest.Assert(latestSqlString.Val(), sqlString)
|
||||
})
|
||||
}
|
@ -681,6 +681,28 @@ func Test_Model_Struct(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Model_Struct_CustomType(t *testing.T) {
|
||||
table := createInitTable()
|
||||
defer dropTable(table)
|
||||
|
||||
type MyInt int
|
||||
|
||||
gtest.Case(t, func() {
|
||||
type User struct {
|
||||
Id MyInt
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
err := db.Table(table).Where("id=1").Struct(user)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(user.NickName, "name_1")
|
||||
gtest.Assert(user.CreateTime.String(), "2018-10-24 10:00:00")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Model_Structs(t *testing.T) {
|
||||
table := createInitTable()
|
||||
defer dropTable(table)
|
||||
|
@ -54,9 +54,9 @@ type PoolStats struct {
|
||||
}
|
||||
|
||||
const (
|
||||
gDEFAULT_POOL_IDLE_TIMEOUT = 60 * time.Second
|
||||
gDEFAULT_POOL_IDLE_TIMEOUT = 30 * time.Second
|
||||
gDEFAULT_POOL_CONN_TIMEOUT = 10 * time.Second
|
||||
gDEFAULT_POOL_MAX_LIFE_TIME = 60 * time.Second
|
||||
gDEFAULT_POOL_MAX_LIFE_TIME = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
@ -80,6 +80,7 @@ func New(config Config) *Redis {
|
||||
config: config,
|
||||
pool: pools.GetOrSetFuncLock(fmt.Sprintf("%v", config), func() interface{} {
|
||||
return &redis.Pool{
|
||||
Wait: true,
|
||||
IdleTimeout: config.IdleTimeout,
|
||||
MaxActive: config.MaxActive,
|
||||
MaxIdle: config.MaxIdle,
|
||||
|
@ -210,6 +210,11 @@ func Test_Error(t *testing.T) {
|
||||
func Test_Bool(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
redis := gredis.New(config)
|
||||
defer func() {
|
||||
redis.Do("DEL", "key-true")
|
||||
redis.Do("DEL", "key-false")
|
||||
}()
|
||||
|
||||
_, err := redis.Do("SET", "key-true", true)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
@ -230,6 +235,8 @@ func Test_Int(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
redis := gredis.New(config)
|
||||
key := guuid.New()
|
||||
defer redis.Do("DEL", key)
|
||||
|
||||
_, err := redis.Do("SET", key, 1)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
@ -243,6 +250,8 @@ func Test_HSet(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
redis := gredis.New(config)
|
||||
key := guuid.New()
|
||||
defer redis.Do("DEL", key)
|
||||
|
||||
_, err := redis.Do("HSET", key, "name", "john")
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
@ -251,3 +260,24 @@ func Test_HSet(t *testing.T) {
|
||||
gtest.Assert(r.Strings(), g.ArrayStr{"name", "john"})
|
||||
})
|
||||
}
|
||||
|
||||
func Test_HGetAll(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
var err error
|
||||
redis := gredis.New(config)
|
||||
key := guuid.New()
|
||||
defer redis.Do("DEL", key)
|
||||
|
||||
_, err = redis.Do("HSET", key, "id", "100")
|
||||
gtest.Assert(err, nil)
|
||||
_, err = redis.Do("HSET", key, "name", "john")
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
r, err := redis.DoVar("HGETALL", key)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(r.Map(), g.MapStrAny{
|
||||
"id": 100,
|
||||
"name": "john",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -10,6 +10,9 @@ package gdebug
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
@ -19,7 +22,6 @@ import (
|
||||
"github.com/gogf/gf/encoding/ghash"
|
||||
|
||||
"github.com/gogf/gf/crypto/gmd5"
|
||||
"github.com/gogf/gf/os/gfile"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -31,20 +33,30 @@ var (
|
||||
goRootForFilter = runtime.GOROOT() // goRootForFilter is used for stack filtering purpose.
|
||||
binaryVersion = "" // The version of current running binary(uint64 hex).
|
||||
binaryVersionMd5 = "" // The version of current running binary(MD5).
|
||||
selfPath = "" // Current running binary absolute path.
|
||||
)
|
||||
|
||||
func init() {
|
||||
if goRootForFilter != "" {
|
||||
goRootForFilter = strings.Replace(goRootForFilter, "\\", "/", -1)
|
||||
}
|
||||
// Initialize internal package variable: selfPath.
|
||||
selfPath, _ := exec.LookPath(os.Args[0])
|
||||
if selfPath != "" {
|
||||
selfPath, _ = filepath.Abs(selfPath)
|
||||
}
|
||||
if selfPath == "" {
|
||||
selfPath, _ = filepath.Abs(os.Args[0])
|
||||
}
|
||||
}
|
||||
|
||||
// BinVersion returns the version of current running binary.
|
||||
// It uses ghash.BKDRHash+BASE36 algorithm to calculate the unique version of the binary.
|
||||
func BinVersion() string {
|
||||
if binaryVersion == "" {
|
||||
binaryContent, _ := ioutil.ReadFile(selfPath)
|
||||
binaryVersion = strconv.FormatInt(
|
||||
int64(ghash.BKDRHash(gfile.GetBytes(gfile.SelfPath()))),
|
||||
int64(ghash.BKDRHash(binaryContent)),
|
||||
36,
|
||||
)
|
||||
}
|
||||
@ -55,7 +67,7 @@ func BinVersion() string {
|
||||
// It uses MD5 algorithm to calculate the unique version of the binary.
|
||||
func BinVersionMd5() string {
|
||||
if binaryVersionMd5 == "" {
|
||||
binaryVersionMd5, _ = gmd5.EncryptFile(gfile.SelfPath())
|
||||
binaryVersionMd5, _ = gmd5.EncryptFile(selfPath)
|
||||
}
|
||||
return binaryVersionMd5
|
||||
}
|
||||
|
39
encoding/gcompress/gcompress_z_unit_gzip_test.go
Normal file
39
encoding/gcompress/gcompress_z_unit_gzip_test.go
Normal file
@ -0,0 +1,39 @@
|
||||
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package gcompress_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/encoding/gcompress"
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_Gzip_UnGzip(t *testing.T) {
|
||||
src := "Hello World!!"
|
||||
|
||||
gzip := []byte{
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
0xf2, 0x48, 0xcd, 0xc9, 0xc9,
|
||||
0x57, 0x08, 0xcf, 0x2f, 0xca,
|
||||
0x49, 0x51, 0x54, 0x04, 0x04,
|
||||
0x00, 0x00, 0xff, 0xff, 0x9d,
|
||||
0x24, 0xa8, 0xd1, 0x0d, 0x00,
|
||||
0x00, 0x00,
|
||||
}
|
||||
|
||||
arr := []byte(src)
|
||||
data, _ := gcompress.Gzip(arr)
|
||||
gtest.Assert(data, gzip)
|
||||
|
||||
data, _ = gcompress.UnGzip(gzip)
|
||||
gtest.Assert(data, arr)
|
||||
|
||||
data, _ = gcompress.UnGzip(gzip[1:])
|
||||
gtest.Assert(data, nil)
|
||||
}
|
153
encoding/gcompress/gcompress_z_unit_zip_test.go
Normal file
153
encoding/gcompress/gcompress_z_unit_zip_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package gcompress_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/gogf/gf/debug/gdebug"
|
||||
"github.com/gogf/gf/encoding/gcompress"
|
||||
"github.com/gogf/gf/os/gfile"
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_ZipPath(t *testing.T) {
|
||||
// file
|
||||
gtest.Case(t, func() {
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1", "1.txt")
|
||||
dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip")
|
||||
|
||||
gtest.Assert(gfile.Exists(dstPath), false)
|
||||
err := gcompress.ZipPath(srcPath, dstPath)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(gfile.Exists(dstPath), true)
|
||||
defer gfile.Remove(dstPath)
|
||||
|
||||
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
err = gfile.Mkdir(tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
err = gcompress.UnZipFile(dstPath, tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
defer gfile.Remove(tempDirPath)
|
||||
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "1.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
|
||||
)
|
||||
})
|
||||
// directory
|
||||
gtest.Case(t, func() {
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip")
|
||||
dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip")
|
||||
|
||||
pwd := gfile.Pwd()
|
||||
err := gfile.Chdir(srcPath)
|
||||
defer gfile.Chdir(pwd)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
gtest.Assert(gfile.Exists(dstPath), false)
|
||||
err = gcompress.ZipPath(srcPath, dstPath)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(gfile.Exists(dstPath), true)
|
||||
defer gfile.Remove(dstPath)
|
||||
|
||||
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
err = gfile.Mkdir(tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
err = gcompress.UnZipFile(dstPath, tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
defer gfile.Remove(tempDirPath)
|
||||
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "zip", "path1", "1.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
|
||||
)
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "zip", "path2", "2.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")),
|
||||
)
|
||||
})
|
||||
// multiple paths joined using char ','
|
||||
gtest.Case(t, func() {
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip")
|
||||
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1")
|
||||
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path2")
|
||||
dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip")
|
||||
|
||||
pwd := gfile.Pwd()
|
||||
err := gfile.Chdir(srcPath)
|
||||
defer gfile.Chdir(pwd)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
gtest.Assert(gfile.Exists(dstPath), false)
|
||||
err = gcompress.ZipPath(srcPath1+", "+srcPath2, dstPath)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(gfile.Exists(dstPath), true)
|
||||
defer gfile.Remove(dstPath)
|
||||
|
||||
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
err = gfile.Mkdir(tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
zipContent := gfile.GetBytes(dstPath)
|
||||
gtest.AssertGT(len(zipContent), 0)
|
||||
err = gcompress.UnZipContent(zipContent, tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
defer gfile.Remove(tempDirPath)
|
||||
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "path1", "1.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
|
||||
)
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "path2", "2.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_ZipPathWriter(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip")
|
||||
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1")
|
||||
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path2")
|
||||
|
||||
pwd := gfile.Pwd()
|
||||
err := gfile.Chdir(srcPath)
|
||||
defer gfile.Chdir(pwd)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
writer := bytes.NewBuffer(nil)
|
||||
gtest.Assert(writer.Len(), 0)
|
||||
err = gcompress.ZipPathWriter(srcPath1+", "+srcPath2, writer)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.AssertGT(writer.Len(), 0)
|
||||
|
||||
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
err = gfile.Mkdir(tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
|
||||
zipContent := writer.Bytes()
|
||||
gtest.AssertGT(len(zipContent), 0)
|
||||
err = gcompress.UnZipContent(zipContent, tempDirPath)
|
||||
gtest.Assert(err, nil)
|
||||
defer gfile.Remove(tempDirPath)
|
||||
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "path1", "1.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
|
||||
)
|
||||
gtest.Assert(
|
||||
gfile.GetContents(gfile.Join(tempDirPath, "path2", "2.txt")),
|
||||
gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")),
|
||||
)
|
||||
})
|
||||
}
|
@ -13,7 +13,7 @@ import (
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func TestZlib(t *testing.T) {
|
||||
func Test_Zlib_UnZlib(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
src := "hello, world\n"
|
||||
dst := []byte{120, 156, 202, 72, 205, 201, 201, 215, 81, 40, 207, 47, 202, 73, 225, 2, 4, 0, 0, 255, 255, 33, 231, 4, 147}
|
||||
@ -31,30 +31,4 @@ func TestZlib(t *testing.T) {
|
||||
data, _ = gcompress.UnZlib(dst[1:])
|
||||
gtest.Assert(data, nil)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestGzip(t *testing.T) {
|
||||
src := "Hello World!!"
|
||||
|
||||
gzip := []byte{
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0xff,
|
||||
0xf2, 0x48, 0xcd, 0xc9, 0xc9,
|
||||
0x57, 0x08, 0xcf, 0x2f, 0xca,
|
||||
0x49, 0x51, 0x54, 0x04, 0x04,
|
||||
0x00, 0x00, 0xff, 0xff, 0x9d,
|
||||
0x24, 0xa8, 0xd1, 0x0d, 0x00,
|
||||
0x00, 0x00,
|
||||
}
|
||||
|
||||
arr := []byte(src)
|
||||
data, _ := gcompress.Gzip(arr)
|
||||
gtest.Assert(data, gzip)
|
||||
|
||||
data, _ = gcompress.UnGzip(gzip)
|
||||
gtest.Assert(data, arr)
|
||||
|
||||
data, _ = gcompress.UnGzip(gzip[1:])
|
||||
gtest.Assert(data, nil)
|
||||
}
|
@ -9,6 +9,7 @@ package gcompress
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"github.com/gogf/gf/internal/intlog"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -32,7 +33,15 @@ func ZipPath(paths, dest string, prefix ...string) error {
|
||||
return err
|
||||
}
|
||||
defer writer.Close()
|
||||
return ZipPathWriter(paths, writer, prefix...)
|
||||
zipWriter := zip.NewWriter(writer)
|
||||
defer zipWriter.Close()
|
||||
for _, path := range strings.Split(paths, ",") {
|
||||
path = strings.TrimSpace(path)
|
||||
if err := doZipPathWriter(path, gfile.RealPath(dest), zipWriter, prefix...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ZipPathWriter compresses <paths> to <writer> using zip compressing algorithm.
|
||||
@ -45,17 +54,21 @@ func ZipPathWriter(paths string, writer io.Writer, prefix ...string) error {
|
||||
defer zipWriter.Close()
|
||||
for _, path := range strings.Split(paths, ",") {
|
||||
path = strings.TrimSpace(path)
|
||||
if err := doZipPathWriter(path, zipWriter, prefix...); err != nil {
|
||||
if err := doZipPathWriter(path, "", zipWriter, prefix...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error {
|
||||
// doZipPathWriter compresses the file of given <path> and writes the content to <zipWriter>.
|
||||
// The parameter <exclude> specifies the exclusive file path that is not compressed to <zipWriter>,
|
||||
// commonly the destination zip file path.
|
||||
// The unnecessary parameter <prefix> indicates the path prefix for zip file.
|
||||
func doZipPathWriter(path string, exclude string, zipWriter *zip.Writer, prefix ...string) error {
|
||||
var err error
|
||||
var files []string
|
||||
realPath, err := gfile.Search(path)
|
||||
path, err = gfile.Search(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -80,7 +93,11 @@ func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error
|
||||
}
|
||||
headerPrefix = strings.Replace(headerPrefix, "//", "/", -1)
|
||||
for _, file := range files {
|
||||
err := zipFile(file, headerPrefix+gfile.Dir(file[len(realPath):]), zipWriter)
|
||||
if exclude == file {
|
||||
intlog.Printf(`exclude file path: %s`, file)
|
||||
continue
|
||||
}
|
||||
err := zipFile(file, headerPrefix+gfile.Dir(file[len(path):]), zipWriter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -101,10 +118,10 @@ func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error
|
||||
}
|
||||
|
||||
// UnZipFile decompresses <archive> to <dest> using zip compressing algorithm.
|
||||
// The parameter <path> specifies the unzipped path of <archive>,
|
||||
// The optional parameter <path> specifies the unzipped path of <archive>,
|
||||
// which can be used to specify part of the archive file to unzip.
|
||||
//
|
||||
// Note thate the parameter <dest> should be a directory.
|
||||
// Note that the parameter <dest> should be a directory.
|
||||
func UnZipFile(archive, dest string, path ...string) error {
|
||||
readerCloser, err := zip.OpenReader(archive)
|
||||
if err != nil {
|
||||
@ -118,7 +135,7 @@ func UnZipFile(archive, dest string, path ...string) error {
|
||||
// The parameter <path> specifies the unzipped path of <archive>,
|
||||
// which can be used to specify part of the archive file to unzip.
|
||||
//
|
||||
// Note thate the parameter <dest> should be a directory.
|
||||
// Note that the parameter <dest> should be a directory.
|
||||
func UnZipContent(data []byte, dest string, path ...string) error {
|
||||
reader, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
|
||||
if err != nil {
|
||||
@ -178,6 +195,8 @@ func unZipFileWithReader(reader *zip.Reader, dest string, path ...string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// zipFile compresses the file of given <path> and writes the content to <zw>.
|
||||
// The parameter <prefix> indicates the path prefix for zip file.
|
||||
func zipFile(path string, prefix string, zw *zip.Writer) error {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
|
1
encoding/gcompress/testdata/zip/path1/1.txt
vendored
Normal file
1
encoding/gcompress/testdata/zip/path1/1.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
This is a test file for zip compression purpose.
|
1
encoding/gcompress/testdata/zip/path2/2.txt
vendored
Normal file
1
encoding/gcompress/testdata/zip/path2/2.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
This is an another test file for zip compression purpose.
|
@ -136,7 +136,7 @@ func Test_GetMap(t *testing.T) {
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(j.GetMap("n"), nil)
|
||||
gtest.Assert(j.GetMap("m"), g.Map{"k": "v"})
|
||||
gtest.Assert(j.GetMap("a"), nil)
|
||||
gtest.Assert(j.GetMap("a"), g.Map{"1": "2", "3": nil})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -95,7 +95,7 @@ func Test_GetVar(t *testing.T) {
|
||||
gtest.Assert(j.GetVar("m").Map(), g.Map{"k": "v"})
|
||||
gtest.Assert(j.GetVar("a").Interfaces(), g.Slice{1, 2, 3})
|
||||
gtest.Assert(j.GetVar("a").Slice(), g.Slice{1, 2, 3})
|
||||
gtest.Assert(j.GetVar("a").Array(), g.Slice{1, 2, 3})
|
||||
gtest.Assert(j.GetMap("a"), g.Map{"1": "2", "3": nil})
|
||||
})
|
||||
}
|
||||
|
||||
@ -106,7 +106,7 @@ func Test_GetMap(t *testing.T) {
|
||||
gtest.AssertNE(j, nil)
|
||||
gtest.Assert(j.GetMap("n"), nil)
|
||||
gtest.Assert(j.GetMap("m"), g.Map{"k": "v"})
|
||||
gtest.Assert(j.GetMap("a"), nil)
|
||||
gtest.Assert(j.GetMap("a"), g.Map{"1": "2", "3": nil})
|
||||
})
|
||||
}
|
||||
|
||||
|
3
go.mod
3
go.mod
@ -7,8 +7,8 @@ require (
|
||||
github.com/clbanning/mxj v1.8.4
|
||||
github.com/fatih/structs v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.4.7
|
||||
github.com/gf-third/mysql v1.4.2
|
||||
github.com/gf-third/yaml v1.0.1
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/gomodule/redigo v2.0.0+incompatible
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/gorilla/websocket v1.4.1
|
||||
@ -17,5 +17,4 @@ require (
|
||||
github.com/olekukonko/tablewriter v0.0.1
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e // indirect
|
||||
golang.org/x/text v0.3.2
|
||||
google.golang.org/appengine v1.6.5 // indirect
|
||||
)
|
||||
|
@ -4,39 +4,38 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package rwmutex provides switch of concurrent safe feature for sync.RWMutex.
|
||||
// Package rwmutex provides switch of concurrent safety feature for sync.RWMutex.
|
||||
package rwmutex
|
||||
|
||||
import "sync"
|
||||
|
||||
// RWMutex is a sync.RWMutex with a switch of concurrent safe feature.
|
||||
// If its attribute *sync.RWMutex is not nil, it means it's in concurrent safety usage.
|
||||
// Its attribute *sync.RWMutex is nil in default, which makes this struct mush lightweight.
|
||||
type RWMutex struct {
|
||||
sync.RWMutex
|
||||
safe bool
|
||||
*sync.RWMutex
|
||||
}
|
||||
|
||||
// New creates and returns a new *RWMutex.
|
||||
// The parameter <safe> is used to specify whether using this mutex in concurrent-safety,
|
||||
// The parameter <safe> is used to specify whether using this mutex in concurrent safety,
|
||||
// which is false in default.
|
||||
func New(safe ...bool) *RWMutex {
|
||||
mu := new(RWMutex)
|
||||
if len(safe) > 0 {
|
||||
mu.safe = safe[0]
|
||||
} else {
|
||||
mu.safe = false
|
||||
if len(safe) > 0 && safe[0] {
|
||||
mu.RWMutex = new(sync.RWMutex)
|
||||
}
|
||||
return mu
|
||||
}
|
||||
|
||||
// IsSafe checks and returns whether current mutex is in concurrent-safe usage.
|
||||
func (mu *RWMutex) IsSafe() bool {
|
||||
return mu.safe
|
||||
return mu.RWMutex != nil
|
||||
}
|
||||
|
||||
// Lock locks mutex for writing.
|
||||
// It does nothing if it is not in concurrent-safe usage.
|
||||
func (mu *RWMutex) Lock() {
|
||||
if mu.safe {
|
||||
if mu.RWMutex != nil {
|
||||
mu.RWMutex.Lock()
|
||||
}
|
||||
}
|
||||
@ -44,7 +43,7 @@ func (mu *RWMutex) Lock() {
|
||||
// Unlock unlocks mutex for writing.
|
||||
// It does nothing if it is not in concurrent-safe usage.
|
||||
func (mu *RWMutex) Unlock() {
|
||||
if mu.safe {
|
||||
if mu.RWMutex != nil {
|
||||
mu.RWMutex.Unlock()
|
||||
}
|
||||
}
|
||||
@ -52,7 +51,7 @@ func (mu *RWMutex) Unlock() {
|
||||
// RLock locks mutex for reading.
|
||||
// It does nothing if it is not in concurrent-safe usage.
|
||||
func (mu *RWMutex) RLock() {
|
||||
if mu.safe {
|
||||
if mu.RWMutex != nil {
|
||||
mu.RWMutex.RLock()
|
||||
}
|
||||
}
|
||||
@ -60,7 +59,7 @@ func (mu *RWMutex) RLock() {
|
||||
// RUnlock unlocks mutex for reading.
|
||||
// It does nothing if it is not in concurrent-safe usage.
|
||||
func (mu *RWMutex) RUnlock() {
|
||||
if mu.safe {
|
||||
if mu.RWMutex != nil {
|
||||
mu.RWMutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
@ -50,9 +50,7 @@ func niceCallFunc(f func()) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
switch err {
|
||||
case gEXCEPTION_EXIT:
|
||||
fallthrough
|
||||
case gEXCEPTION_EXIT_ALL:
|
||||
case gEXCEPTION_EXIT, gEXCEPTION_EXIT_ALL:
|
||||
return
|
||||
default:
|
||||
panic(err)
|
||||
|
@ -8,6 +8,7 @@ package ghttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gogf/gf/container/gmap"
|
||||
"github.com/gogf/gf/os/gres"
|
||||
"github.com/gogf/gf/os/gview"
|
||||
"net/http"
|
||||
@ -31,6 +32,7 @@ type Request struct {
|
||||
LeaveTime int64 // Request ending time in microseconds.
|
||||
Middleware *Middleware // The middleware manager.
|
||||
StaticFile *StaticFile // Static file object when static file serving.
|
||||
Context *gmap.StrAnyMap // Custom context map for internal usage purpose.
|
||||
handlers []*handlerParsedItem // All matched handlers containing handler, hook and middleware for this request .
|
||||
hasHookHandler bool // A bool marking whether there's hook handler in the handlers for performance purpose.
|
||||
hasServeHandler bool // A bool marking whether there's serving handler in the handlers for performance purpose.
|
||||
@ -66,6 +68,7 @@ func newRequest(s *Server, r *http.Request, w http.ResponseWriter) *Request {
|
||||
Request: r,
|
||||
Response: newResponse(s, w),
|
||||
EnterTime: gtime.TimestampMilli(),
|
||||
Context: gmap.NewStrAnyMap(),
|
||||
}
|
||||
request.Cookie = GetCookie(request)
|
||||
request.Session = s.sessionManager.New(request.GetSessionId())
|
||||
|
@ -123,6 +123,7 @@ func (m *Middleware) Next() {
|
||||
}, func(exception interface{}) {
|
||||
m.request.error = gerror.Newf("%v", exception)
|
||||
m.request.Response.WriteStatus(http.StatusInternalServerError, exception)
|
||||
loop = false
|
||||
})
|
||||
}
|
||||
// Check the http status code after all handler and middleware done.
|
||||
|
@ -9,6 +9,7 @@ package ghttp
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/container/gvar"
|
||||
"github.com/gogf/gf/encoding/gjson"
|
||||
"github.com/gogf/gf/encoding/gurl"
|
||||
@ -302,5 +303,19 @@ func (r *Request) GetMultipartFiles(name string) []*multipart.FileHeader {
|
||||
if v := form.File[name+"[]"]; len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
// Support "name[0]","name[1]","name[2]", etc. as array parameter.
|
||||
key := ""
|
||||
files := make([]*multipart.FileHeader, 0)
|
||||
for i := 0; ; i++ {
|
||||
key = fmt.Sprintf(`%s[%d]`, name, i)
|
||||
if v := form.File[key]; len(v) > 0 {
|
||||
files = append(files, v[0])
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(files) > 0 {
|
||||
return files
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -8,12 +8,12 @@ package ghttp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gogf/gf/internal/intlog"
|
||||
"github.com/gogf/gf/os/gfile"
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
"github.com/gogf/gf/util/grand"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@ -26,66 +26,67 @@ type UploadFile struct {
|
||||
// UploadFiles is array type for *UploadFile.
|
||||
type UploadFiles []*UploadFile
|
||||
|
||||
// Save saves the single uploading file to specified path.
|
||||
// The parameter path can be either a directory or a file path. If <path> is a directory,
|
||||
// it saves the uploading file to the directory using its original name. If <path> is a
|
||||
// file path, it saves the uploading file to the file path.
|
||||
// Save saves the single uploading file to directory path and returns the saved file name.
|
||||
//
|
||||
// The parameter <dirPath> should be a directory path or it returns error.
|
||||
//
|
||||
// The parameter <randomlyRename> specifies whether randomly renames the file name, which
|
||||
// make sense if the <path> is a directory.
|
||||
//
|
||||
// Note that it will overwrite the target file if there's already a same name file exist.
|
||||
func (f *UploadFile) Save(path string, randomlyRename ...bool) error {
|
||||
func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename string, err error) {
|
||||
if f == nil {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
if !gfile.Exists(dirPath) {
|
||||
if err = gfile.Mkdir(dirPath); err != nil {
|
||||
return
|
||||
}
|
||||
} else if !gfile.IsDir(dirPath) {
|
||||
return "", errors.New(`parameter "dirPath" should be a directory path`)
|
||||
}
|
||||
|
||||
file, err := f.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var newFile *os.File
|
||||
if gfile.IsDir(path) {
|
||||
filename := gfile.Basename(f.Filename)
|
||||
if len(randomlyRename) > 0 && randomlyRename[0] {
|
||||
filename = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6))
|
||||
filename = filename + gfile.Ext(f.Filename)
|
||||
}
|
||||
newFile, err = gfile.Create(gfile.Join(path, filename))
|
||||
} else {
|
||||
newFile, err = gfile.Create(path)
|
||||
name := gfile.Basename(f.Filename)
|
||||
if len(randomlyRename) > 0 && randomlyRename[0] {
|
||||
name = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6))
|
||||
name = name + gfile.Ext(f.Filename)
|
||||
}
|
||||
filePath := gfile.Join(dirPath, name)
|
||||
newFile, err := gfile.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
defer newFile.Close()
|
||||
|
||||
intlog.Printf(`save upload file: %s`, filePath)
|
||||
if _, err := io.Copy(newFile, file); err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
return nil
|
||||
return gfile.Basename(filePath), nil
|
||||
}
|
||||
|
||||
// Save saves all uploading files to specified directory path.
|
||||
// Save saves all uploading files to specified directory path and returns the saved file names.
|
||||
//
|
||||
// The parameter <dirPath> should be a directory path or it returns error.
|
||||
//
|
||||
// The parameter <randomlyRename> specifies whether randomly renames all the file names.
|
||||
func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) error {
|
||||
func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) (filenames []string, err error) {
|
||||
if len(fs) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
if !gfile.IsDir(dirPath) {
|
||||
return errors.New(`parameter "dirPath" should be a directory path`)
|
||||
}
|
||||
var err error
|
||||
for _, f := range fs {
|
||||
if err = f.Save(dirPath, randomlyRename...); err != nil {
|
||||
return err
|
||||
if filename, err := f.Save(dirPath, randomlyRename...); err != nil {
|
||||
return filenames, err
|
||||
} else {
|
||||
filenames = append(filenames, filename)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// GetUploadFile retrieves and returns the uploading file with specified form name.
|
||||
|
64
net/ghttp/ghttp_request_param_page.go
Normal file
64
net/ghttp/ghttp_request_param_page.go
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package ghttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
)
|
||||
|
||||
// GetPage creates and returns the pagination object for given <totalSize> and <pageSize>.
|
||||
// NOTE THAT the page parameter name from client is constantly defined as gpage.PAGE_NAME
|
||||
// for simplification and convenience.
|
||||
func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page {
|
||||
// It must has Router object attribute.
|
||||
if r.Router == nil {
|
||||
panic("Router object not found")
|
||||
}
|
||||
url := *r.URL
|
||||
urlTemplate := url.Path
|
||||
uriHasPageName := false
|
||||
// Check the page variable in the URI.
|
||||
if len(r.Router.RegNames) > 0 {
|
||||
for _, name := range r.Router.RegNames {
|
||||
if name == gpage.PAGE_NAME {
|
||||
uriHasPageName = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if uriHasPageName {
|
||||
if match, err := gregex.MatchString(r.Router.RegRule, url.Path); err == nil && len(match) > 0 {
|
||||
if len(match) > len(r.Router.RegNames) {
|
||||
urlTemplate = r.Router.Uri
|
||||
for i, name := range r.Router.RegNames {
|
||||
rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name)
|
||||
if name == gpage.PAGE_NAME {
|
||||
urlTemplate, _ = gregex.ReplaceString(rule, gpage.PAGE_PLACE_HOLDER, urlTemplate)
|
||||
} else {
|
||||
urlTemplate, _ = gregex.ReplaceString(rule, match[i+1], urlTemplate)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check the page variable in the query string.
|
||||
if !uriHasPageName {
|
||||
values := url.Query()
|
||||
values.Set(gpage.PAGE_NAME, gpage.PAGE_PLACE_HOLDER)
|
||||
url.RawQuery = values.Encode()
|
||||
// Replace the encodes "{.page}" to "{.page}".
|
||||
url.RawQuery = gstr.Replace(url.RawQuery, "%7B.page%7D", "{.page}")
|
||||
}
|
||||
if url.RawQuery != "" {
|
||||
urlTemplate += "?" + url.RawQuery
|
||||
}
|
||||
|
||||
return gpage.New(totalSize, pageSize, r.GetInt(gpage.PAGE_NAME), urlTemplate)
|
||||
}
|
@ -8,11 +8,10 @@
|
||||
package ghttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
"github.com/gogf/gf/util/gconv"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// CORSOptions is the options for CORS feature.
|
||||
@ -27,6 +26,20 @@ type CORSOptions struct {
|
||||
AllowHeaders string // Access-Control-Allow-Headers
|
||||
}
|
||||
|
||||
var (
|
||||
// defaultAllowHeaders is the default allowed headers for CORS.
|
||||
// It's defined another map for better header key searching performance.
|
||||
defaultAllowHeaders = "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With"
|
||||
defaultAllowHeadersMap = make(map[string]struct{})
|
||||
)
|
||||
|
||||
func init() {
|
||||
array := gstr.SplitAndTrim(defaultAllowHeaders, ",")
|
||||
for _, header := range array {
|
||||
defaultAllowHeadersMap[header] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCORSOptions returns the default CORS options,
|
||||
// which allows any cross-domain request.
|
||||
func (r *Response) DefaultCORSOptions() CORSOptions {
|
||||
@ -34,9 +47,19 @@ func (r *Response) DefaultCORSOptions() CORSOptions {
|
||||
AllowOrigin: "*",
|
||||
AllowMethods: HTTP_METHODS,
|
||||
AllowCredentials: "true",
|
||||
AllowHeaders: "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With",
|
||||
AllowHeaders: defaultAllowHeaders,
|
||||
MaxAge: 3628800,
|
||||
}
|
||||
// Allow all client's custom headers in default.
|
||||
if headers := r.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
||||
array := gstr.SplitAndTrim(headers, ",")
|
||||
for _, header := range array {
|
||||
if _, ok := defaultAllowHeadersMap[header]; !ok {
|
||||
options.AllowHeaders += header + ","
|
||||
}
|
||||
}
|
||||
}
|
||||
// Allow all anywhere origin in default.
|
||||
if origin := r.Request.Header.Get("Origin"); origin != "" {
|
||||
options.AllowOrigin = origin
|
||||
} else if referer := r.Request.Referer(); referer != "" {
|
||||
@ -72,8 +95,26 @@ func (r *Response) CORS(options CORSOptions) {
|
||||
}
|
||||
// No continue service handling if it's OPTIONS request.
|
||||
if gstr.Equal(r.Request.Method, "OPTIONS") {
|
||||
// Request method handler searching.
|
||||
// It here simply uses Server.routesMap attribute enhancing the searching performance.
|
||||
if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" {
|
||||
routerKey := ""
|
||||
for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} {
|
||||
for _, v := range []string{gDEFAULT_METHOD, method} {
|
||||
routerKey = r.Server.routerMapKey("", v, r.Request.URL.Path, domain)
|
||||
if r.Server.routesMap[routerKey] != nil {
|
||||
if r.Status == 0 {
|
||||
r.Status = http.StatusOK
|
||||
}
|
||||
// No continue serving.
|
||||
r.Request.ExitAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Cannot find the request serving handler, it then responses 404.
|
||||
if r.Status == 0 {
|
||||
r.Status = http.StatusOK
|
||||
r.Status = http.StatusNotFound
|
||||
}
|
||||
r.Request.ExitAll()
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/debug/gdebug"
|
||||
"github.com/gogf/gf/internal/intlog"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
@ -78,7 +79,7 @@ type (
|
||||
// handlerItem is the registered handler for route handling,
|
||||
// including middleware and hook functions.
|
||||
handlerItem struct {
|
||||
itemId int // Unique ID mark.
|
||||
itemId int // Unique handler item id mark.
|
||||
itemName string // Handler name, which is automatically retrieved from runtime stack when registered.
|
||||
itemType int // Handler type: object/handler/controller/middleware/hook.
|
||||
itemFunc HandlerFunc // Handler address.
|
||||
@ -143,7 +144,7 @@ var (
|
||||
// it is used for quick HTTP method searching using map.
|
||||
methodsMap = make(map[string]struct{})
|
||||
|
||||
// serverMapping stores more than one server instances.
|
||||
// serverMapping stores more than one server instances for current process.
|
||||
// The key is the name of the server, and the value is its instance.
|
||||
serverMapping = gmap.NewStrAnyMap(true)
|
||||
|
||||
@ -444,14 +445,20 @@ func (s *Server) GetRouterArray() []RouterItem {
|
||||
}
|
||||
|
||||
// Run starts server listening in blocking way.
|
||||
// It's commonly used for single server situation.
|
||||
func (s *Server) Run() {
|
||||
if err := s.Start(); err != nil {
|
||||
s.Logger().Fatal(err)
|
||||
}
|
||||
|
||||
// Blocking using channel.
|
||||
<-s.closeChan
|
||||
|
||||
// Remove plugins.
|
||||
if len(s.plugins) > 0 {
|
||||
for _, p := range s.plugins {
|
||||
intlog.Printf(`remove plugin: %s`, p.Name())
|
||||
p.Remove()
|
||||
}
|
||||
}
|
||||
s.Logger().Printf("[ghttp] %d: all servers shutdown", gproc.Pid())
|
||||
}
|
||||
|
||||
@ -459,7 +466,17 @@ func (s *Server) Run() {
|
||||
// It's commonly used in multiple servers situation.
|
||||
func Wait() {
|
||||
<-allDoneChan
|
||||
|
||||
// Remove plugins.
|
||||
serverMapping.Iterator(func(k string, v interface{}) bool {
|
||||
s := v.(*Server)
|
||||
if len(s.plugins) > 0 {
|
||||
for _, p := range s.plugins {
|
||||
intlog.Printf(`remove plugin: %s`, p.Name())
|
||||
p.Remove()
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
glog.Printf("[ghttp] %d: all servers shutdown", gproc.Pid())
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,7 @@ package ghttp
|
||||
type Plugin interface {
|
||||
Name() string // Name returns the name of the plugin.
|
||||
Author() string // Author returns the author of the plugin.
|
||||
Version() string // Version returns the version of the plugin.
|
||||
Version() string // Version returns the version of the plugin, like "v1.0.0".
|
||||
Description() string // Description returns the description of the plugin.
|
||||
Install(s *Server) error // Install installs the plugin before server starts.
|
||||
Remove() error // Remove removes the plugin.
|
||||
|
@ -24,11 +24,18 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// 用于服务函数的ID生成变量
|
||||
// handlerIdGenerator is handler item id generator.
|
||||
handlerIdGenerator = gtype.NewInt()
|
||||
)
|
||||
|
||||
// 解析pattern
|
||||
// routerMapKey creates and returns an unique router key for given parameters.
|
||||
// This key is used for Server.routerMap attribute, which is mainly for checks for
|
||||
// repeated router registering.
|
||||
func (s *Server) routerMapKey(hook, method, path, domain string) string {
|
||||
return hook + "%" + s.serveHandlerKey(method, path, domain)
|
||||
}
|
||||
|
||||
// parsePattern parses the given pattern to domain, method and path variable.
|
||||
func (s *Server) parsePattern(pattern string) (domain, method, path string, err error) {
|
||||
path = strings.TrimSpace(pattern)
|
||||
domain = gDEFAULT_DOMAIN
|
||||
@ -48,16 +55,16 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err
|
||||
if path == "" {
|
||||
err = errors.New("invalid pattern: URI should not be empty")
|
||||
}
|
||||
// 去掉末尾的"/"符号,与路由匹配时处理一致
|
||||
if path != "/" {
|
||||
path = strings.TrimRight(path, "/")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 路由注册处理方法。
|
||||
// 非叶节点为哈希表检索节点,按照URI注册的层级进行高效检索,直至到叶子链表节点;
|
||||
// 叶子节点是链表,按照优先级进行排序,优先级高的排前面,按照遍历检索,按照哈希表层级检索后的叶子链表数据量不会很大,所以效率比较高;
|
||||
// setHandler creates router item with given handler and pattern and registers the handler to the router tree.
|
||||
// The router tree can be treated as a multilayer hash table, please refer to the comment in following codes.
|
||||
// This function is called during server starts up, which cares little about the performance. What really cares
|
||||
// is the well designed router storage structure for router searching when the request is under serving.
|
||||
func (s *Server) setHandler(pattern string, handler *handlerItem) {
|
||||
handler.itemId = handlerIdGenerator.Add(1)
|
||||
domain, method, uri, err := s.parsePattern(pattern)
|
||||
@ -69,30 +76,32 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
|
||||
s.Logger().Fatal("invalid pattern:", pattern, "URI should lead with '/'")
|
||||
return
|
||||
}
|
||||
// 注册地址记录及重复注册判断
|
||||
regKey := s.handlerKey(handler.hookName, method, uri, domain)
|
||||
|
||||
// Repeated router checks, this feature can be disabled by server configuration.
|
||||
routerKey := s.routerMapKey(handler.hookName, method, uri, domain)
|
||||
if !s.config.RouteOverWrite {
|
||||
switch handler.itemType {
|
||||
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
|
||||
if item, ok := s.routesMap[regKey]; ok {
|
||||
if item, ok := s.routesMap[routerKey]; ok {
|
||||
s.Logger().Fatalf(`duplicated route registry "%s", already registered at %s`, pattern, item[0].file)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// 注册的路由信息对象
|
||||
// Create a new router by given parameter.
|
||||
handler.router = &Router{
|
||||
Uri: uri,
|
||||
Domain: domain,
|
||||
Method: method,
|
||||
Method: strings.ToUpper(method),
|
||||
Priority: strings.Count(uri[1:], "/"),
|
||||
}
|
||||
handler.router.RegRule, handler.router.RegNames = s.patternToRegRule(uri)
|
||||
handler.router.RegRule, handler.router.RegNames = s.patternToRegular(uri)
|
||||
|
||||
if _, ok := s.serveTree[domain]; !ok {
|
||||
s.serveTree[domain] = make(map[string]interface{})
|
||||
}
|
||||
// 当前节点的规则链表
|
||||
// List array, very important for router registering.
|
||||
// There may be multiple lists adding into this array when searching from root to leaf.
|
||||
lists := make([]*glist.List, 0)
|
||||
array := ([]string)(nil)
|
||||
if strings.EqualFold("/", uri) {
|
||||
@ -100,43 +109,58 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
|
||||
} else {
|
||||
array = strings.Split(uri[1:], "/")
|
||||
}
|
||||
// 键名"*fuzz"代表当前节点为模糊匹配节点,该节点也会有一个*list链表;
|
||||
// 键名"*list"代表链表,叶子节点和模糊匹配节点都有该属性,优先级越高越排前;
|
||||
// Multilayer hash table:
|
||||
// 1. Each node of the table is separated by URI path which is split by char '/'.
|
||||
// 2. The key "*fuzz" specifies this node is a fuzzy node, which has no certain name.
|
||||
// 3. The key "*list" is the list item of the node, MOST OF THE NODES HAVE THIS ITEM,
|
||||
// especially the fuzzy node. NOTE THAT the fuzzy node must have the "*list" item,
|
||||
// and the leaf node also has "*list" item. If the node is not a fuzzy node either
|
||||
// a leaf, it neither has "*list" item.
|
||||
// 2. The "*list" item is a list containing registered router items ordered by their
|
||||
// priorities from high to low.
|
||||
// 3. There may be repeated router items in the router lists. The lists' priorities
|
||||
// from root to leaf are from low to high.
|
||||
p := s.serveTree[domain]
|
||||
for k, v := range array {
|
||||
if len(v) == 0 {
|
||||
for i, part := range array {
|
||||
// Ignore empty URI part, like: /user//index
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
// 判断是否模糊匹配规则
|
||||
if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, v) {
|
||||
v = "*fuzz"
|
||||
// 由于是模糊规则,因此这里会有一个*list,用以将后续的路由规则加进来,
|
||||
// 检索会从叶子节点的链表往根节点按照优先级进行检索
|
||||
// Check if it's a fuzzy node.
|
||||
if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, part) {
|
||||
part = "*fuzz"
|
||||
// If it's a fuzzy node, it creates a "*list" item - which is a list - in the hash map.
|
||||
// All the sub router items from this fuzzy node will also be added to its "*list" item.
|
||||
if v, ok := p.(map[string]interface{})["*list"]; !ok {
|
||||
p.(map[string]interface{})["*list"] = glist.New()
|
||||
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
|
||||
newListForFuzzy := glist.New()
|
||||
p.(map[string]interface{})["*list"] = newListForFuzzy
|
||||
lists = append(lists, newListForFuzzy)
|
||||
} else {
|
||||
lists = append(lists, v.(*glist.List))
|
||||
}
|
||||
}
|
||||
// 属性层级数据写入
|
||||
if _, ok := p.(map[string]interface{})[v]; !ok {
|
||||
p.(map[string]interface{})[v] = make(map[string]interface{})
|
||||
// Make a new bucket for current node.
|
||||
if _, ok := p.(map[string]interface{})[part]; !ok {
|
||||
p.(map[string]interface{})[part] = make(map[string]interface{})
|
||||
}
|
||||
p = p.(map[string]interface{})[v]
|
||||
// 到达叶子节点,往list中增加匹配规则(条件 v != "*fuzz" 是因为模糊节点的话在前面已经添加了*list链表)
|
||||
if k == len(array)-1 && v != "*fuzz" {
|
||||
// Loop to next bucket.
|
||||
p = p.(map[string]interface{})[part]
|
||||
// The leaf is a hash map and must have an item named "*list", which contains the router item.
|
||||
// The leaf can be furthermore extended by adding more ket-value pairs into its map.
|
||||
// Note that the `v != "*fuzz"` comparison is required as the list might be added in the former
|
||||
// fuzzy checks.
|
||||
if i == len(array)-1 && part != "*fuzz" {
|
||||
if v, ok := p.(map[string]interface{})["*list"]; !ok {
|
||||
p.(map[string]interface{})["*list"] = glist.New()
|
||||
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
|
||||
leafList := glist.New()
|
||||
p.(map[string]interface{})["*list"] = leafList
|
||||
lists = append(lists, leafList)
|
||||
} else {
|
||||
lists = append(lists, v.(*glist.List))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 上面循环后得到的lists是该路由规则一路匹配下来相关的模糊匹配链表(注意不是这棵树所有的链表)。
|
||||
// 下面从头开始遍历每个节点的模糊匹配链表,将该路由项插入进去(按照优先级高的放在lists链表的前面)
|
||||
// It iterates the list array of <lists>, compares priorities and inserts the new router item in
|
||||
// the proper position of each list. The priority of the list is ordered from high to low.
|
||||
item := (*handlerItem)(nil)
|
||||
for _, l := range lists {
|
||||
pushed := false
|
||||
@ -157,8 +181,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
|
||||
}
|
||||
}
|
||||
// Initialize the route map item.
|
||||
if _, ok := s.routesMap[regKey]; !ok {
|
||||
s.routesMap[regKey] = make([]registeredRouteItem, 0)
|
||||
if _, ok := s.routesMap[routerKey]; !ok {
|
||||
s.routesMap[routerKey] = make([]registeredRouteItem, 0)
|
||||
}
|
||||
_, file, line := gdebug.CallerWithFilter(gFILTER_KEY)
|
||||
routeItem := registeredRouteItem{
|
||||
@ -168,35 +192,39 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
|
||||
switch handler.itemType {
|
||||
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
|
||||
// Overwrite the route.
|
||||
s.routesMap[regKey] = []registeredRouteItem{routeItem}
|
||||
s.routesMap[routerKey] = []registeredRouteItem{routeItem}
|
||||
default:
|
||||
// Append the route.
|
||||
s.routesMap[regKey] = append(s.routesMap[regKey], routeItem)
|
||||
s.routesMap[routerKey] = append(s.routesMap[routerKey], routeItem)
|
||||
}
|
||||
}
|
||||
|
||||
// 对比两个handlerItem的优先级,需要非常注意的是,注意新老对比项的参数先后顺序。
|
||||
// 返回值true表示newItem优先级比oldItem高,会被添加链表中oldRouter的前面;否则后面。
|
||||
// 优先级比较规则:
|
||||
// 1、中间件优先级最高,按照添加顺序优先级执行;
|
||||
// 2、其他路由注册类型,层级越深优先级越高(对比/数量);
|
||||
// 3、模糊规则优先级:{xxx} > :xxx > *xxx;
|
||||
// compareRouterPriority compares the priority between <newItem> and <oldItem>. It returns true
|
||||
// if <newItem>'s priority is higher than <oldItem>, else it returns false. The higher priority
|
||||
// item will be insert into the router list before the other one.
|
||||
//
|
||||
// Comparison rules:
|
||||
// 1. The middleware has the most high priority.
|
||||
// 2. URI: The deeper the higher (simply check the count of char '/' in the URI).
|
||||
// 3. Route type: {xxx} > :xxx > *xxx.
|
||||
func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerItem) bool {
|
||||
// 中间件优先级最高,按照添加顺序优先级执行
|
||||
// If they're all type of middleware, the priority is according their registered sequence.
|
||||
if newItem.itemType == gHANDLER_TYPE_MIDDLEWARE && oldItem.itemType == gHANDLER_TYPE_MIDDLEWARE {
|
||||
return false
|
||||
}
|
||||
// The middleware has the most high priority.
|
||||
if newItem.itemType == gHANDLER_TYPE_MIDDLEWARE && oldItem.itemType != gHANDLER_TYPE_MIDDLEWARE {
|
||||
return true
|
||||
}
|
||||
// 优先比较层级,层级越深优先级越高
|
||||
// URI: The deeper the higher (simply check the count of char '/' in the URI).
|
||||
if newItem.router.Priority > oldItem.router.Priority {
|
||||
return true
|
||||
}
|
||||
if newItem.router.Priority < oldItem.router.Priority {
|
||||
return false
|
||||
}
|
||||
// 精准匹配比模糊匹配规则优先级高,例如:/name/act 比 /{name}/:act 优先级高
|
||||
// Route type: {xxx} > :xxx > *xxx.
|
||||
// Eg: /name/act > /{name}/:act
|
||||
var fuzzyCountFieldNew, fuzzyCountFieldOld int
|
||||
var fuzzyCountNameNew, fuzzyCountNameOld int
|
||||
var fuzzyCountAnyNew, fuzzyCountAnyOld int
|
||||
@ -230,16 +258,16 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
|
||||
return false
|
||||
}
|
||||
|
||||
/** 如果模糊规则数量相等,那么执行分别的数量判断 **/
|
||||
// If the counts of their fuzzy rules equal.
|
||||
|
||||
// 例如:/name/{act} 比 /name/:act 优先级高
|
||||
// Eg: /name/{act} > /name/:act
|
||||
if fuzzyCountFieldNew > fuzzyCountFieldOld {
|
||||
return true
|
||||
}
|
||||
if fuzzyCountFieldNew < fuzzyCountFieldOld {
|
||||
return false
|
||||
}
|
||||
// 例如: /name/:act 比 /name/*act 优先级高
|
||||
// Eg: /name/:act > /name/*act
|
||||
if fuzzyCountNameNew > fuzzyCountNameOld {
|
||||
return true
|
||||
}
|
||||
@ -247,9 +275,10 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
|
||||
return false
|
||||
}
|
||||
|
||||
/** 比较路由规则长度,越长的规则优先级越高,模糊/命名规则不算长度 **/
|
||||
// It then compares the length of their URI,
|
||||
// but the fuzzy and named parts of the URI are not calculated to the result.
|
||||
|
||||
// 例如:/admin-goods-{page} 比 /admin-{page} 优先级高
|
||||
// Eg: /admin-goods-{page} > /admin-{page}
|
||||
var uriNew, uriOld string
|
||||
uriNew, _ = gregex.ReplaceString(`\{[^/]+\}`, "", newItem.router.Uri)
|
||||
uriNew, _ = gregex.ReplaceString(`:[^/]+`, "", uriNew)
|
||||
@ -264,9 +293,8 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
|
||||
return false
|
||||
}
|
||||
|
||||
/* 模糊规则数量相等,后续不用再判断*规则的数量比较了 */
|
||||
|
||||
// 比较HTTP METHOD,更精准的优先级更高
|
||||
// It then compares the accuracy of their http method,
|
||||
// the more accurate the more priority.
|
||||
if newItem.router.Method != gDEFAULT_METHOD {
|
||||
return true
|
||||
}
|
||||
@ -274,23 +302,25 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果是服务路由,那么新的规则比旧的规则优先级高(路由覆盖)
|
||||
// If they have different router type,
|
||||
// the new router item has more priority than the other one.
|
||||
if newItem.itemType == gHANDLER_TYPE_HANDLER ||
|
||||
newItem.itemType == gHANDLER_TYPE_OBJECT ||
|
||||
newItem.itemType == gHANDLER_TYPE_CONTROLLER {
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果是其他路由(HOOK/中间件),那么新的规则比旧的规则优先级低,使得注册相同路由则顺序执行
|
||||
// Other situations, like HOOK items,
|
||||
// the old router item has more priority than the other one.
|
||||
return false
|
||||
}
|
||||
|
||||
// 将pattern(不带method和domain)解析成正则表达式匹配以及对应的query字符串
|
||||
func (s *Server) patternToRegRule(rule string) (regrule string, names []string) {
|
||||
// patternToRegular converts route rule to according regular expression.
|
||||
func (s *Server) patternToRegular(rule string) (regular string, names []string) {
|
||||
if len(rule) < 2 {
|
||||
return rule, nil
|
||||
}
|
||||
regrule = "^"
|
||||
regular = "^"
|
||||
array := strings.Split(rule[1:], "/")
|
||||
for _, v := range array {
|
||||
if len(v) == 0 {
|
||||
@ -299,20 +329,20 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
|
||||
switch v[0] {
|
||||
case ':':
|
||||
if len(v) > 1 {
|
||||
regrule += `/([^/]+)`
|
||||
regular += `/([^/]+)`
|
||||
names = append(names, v[1:])
|
||||
} else {
|
||||
regrule += `/[^/]+`
|
||||
regular += `/[^/]+`
|
||||
}
|
||||
case '*':
|
||||
if len(v) > 1 {
|
||||
regrule += `/{0,1}(.*)`
|
||||
regular += `/{0,1}(.*)`
|
||||
names = append(names, v[1:])
|
||||
} else {
|
||||
regrule += `/{0,1}.*`
|
||||
regular += `/{0,1}.*`
|
||||
}
|
||||
default:
|
||||
// 特殊字符替换
|
||||
// Special chars replacement.
|
||||
v = gstr.ReplaceByMap(v, map[string]string{
|
||||
`.`: `\.`,
|
||||
`+`: `\+`,
|
||||
@ -323,12 +353,12 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
|
||||
return `([^/]+)`
|
||||
})
|
||||
if strings.EqualFold(s, v) {
|
||||
regrule += "/" + v
|
||||
regular += "/" + v
|
||||
} else {
|
||||
regrule += "/" + s
|
||||
regular += "/" + s
|
||||
}
|
||||
}
|
||||
}
|
||||
regrule += `$`
|
||||
regular += `$`
|
||||
return
|
||||
}
|
||||
|
@ -66,8 +66,3 @@ func (s *Server) niceCallHookHandler(f HandlerFunc, r *Request) (err interface{}
|
||||
f(r)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成hook key,如果是hook key,那么使用'%'符号分隔
|
||||
func (s *Server) handlerKey(hook, method, path, domain string) string {
|
||||
return hook + "%" + s.serveHandlerKey(method, path, domain)
|
||||
}
|
||||
|
@ -15,18 +15,37 @@ import (
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
)
|
||||
|
||||
// 缓存数据项
|
||||
// handlerCacheItem is an item just for internal router searching cache.
|
||||
type handlerCacheItem struct {
|
||||
parsedItems []*handlerParsedItem
|
||||
hasHook bool
|
||||
hasServe bool
|
||||
}
|
||||
|
||||
// 查询请求处理方法.
|
||||
// 内部带锁机制,可以并发读,但是不能并发写;并且有缓存机制,按照Host、Method、Path进行缓存.
|
||||
// serveHandlerKey creates and returns a handler key for router.
|
||||
func (s *Server) serveHandlerKey(method, path, domain string) string {
|
||||
if len(domain) > 0 {
|
||||
domain = "@" + domain
|
||||
}
|
||||
if method == "" {
|
||||
return path + strings.ToLower(domain)
|
||||
}
|
||||
return strings.ToUpper(method) + ":" + path + strings.ToLower(domain)
|
||||
}
|
||||
|
||||
// getHandlersWithCache searches the router item with cache feature for given request.
|
||||
func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) {
|
||||
value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} {
|
||||
parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost())
|
||||
method := r.Method
|
||||
// Special http method OPTIONS handling.
|
||||
// It searches the handler with the request method instead of OPTIONS method.
|
||||
if method == "OPTIONS" {
|
||||
if v := r.Request.Header.Get("Access-Control-Request-Method"); v != "" {
|
||||
method = v
|
||||
}
|
||||
}
|
||||
// Search and cache the router handlers.
|
||||
value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(method, r.URL.Path, r.GetHost()), func() interface{} {
|
||||
parsedItems, hasHook, hasServe = s.searchHandlers(method, r.URL.Path, r.GetHost())
|
||||
if parsedItems != nil {
|
||||
return &handlerCacheItem{parsedItems, hasHook, hasServe}
|
||||
}
|
||||
@ -39,18 +58,14 @@ func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedI
|
||||
return
|
||||
}
|
||||
|
||||
// 路由注册方法检索,返回所有该路由的注册函数,构造成数组返回
|
||||
// searchHandlers retrieves and returns the routers with given parameters.
|
||||
// Note that the returned routers contain serving handler, middleware handlers and hook handlers.
|
||||
func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) {
|
||||
if len(path) == 0 {
|
||||
return nil, false, false
|
||||
}
|
||||
// 遍历检索的域名列表,优先遍历默认域名
|
||||
domains := []string{gDEFAULT_DOMAIN}
|
||||
if !strings.EqualFold(gDEFAULT_DOMAIN, domain) {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
// URL.Path层级拆分
|
||||
array := ([]string)(nil)
|
||||
// Split the URL.path to separate parts.
|
||||
var array []string
|
||||
if strings.EqualFold("/", path) {
|
||||
array = []string{"/"}
|
||||
} else {
|
||||
@ -58,85 +73,93 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
|
||||
}
|
||||
parsedItemList := glist.New()
|
||||
lastMiddlewareElem := (*glist.Element)(nil)
|
||||
repeatHandlerCheckMap := make(map[int]struct{})
|
||||
for _, domain := range domains {
|
||||
repeatHandlerCheckMap := make(map[int]struct{}, 16)
|
||||
// Default domain has the most priority when iteration.
|
||||
for _, domain := range []string{gDEFAULT_DOMAIN, domain} {
|
||||
p, ok := s.serveTree[domain]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 多层链表(每个节点都有一个*list链表)的目的是当叶子节点未有任何规则匹配时,让父级模糊匹配规则继续处理
|
||||
// Make a list array with capacity of 16.
|
||||
lists := make([]*glist.List, 0, 16)
|
||||
for k, v := range array {
|
||||
for i, part := range array {
|
||||
// In case of double '/' URI, eg: /user//index
|
||||
if v == "" {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := p.(map[string]interface{})["*list"]; ok {
|
||||
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
|
||||
// Add all list of each node to the list array.
|
||||
if v, ok := p.(map[string]interface{})["*list"]; ok {
|
||||
lists = append(lists, v.(*glist.List))
|
||||
}
|
||||
if _, ok := p.(map[string]interface{})[v]; ok {
|
||||
p = p.(map[string]interface{})[v]
|
||||
if k == len(array)-1 {
|
||||
if _, ok := p.(map[string]interface{})["*list"]; ok {
|
||||
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
|
||||
if v, ok := p.(map[string]interface{})[part]; ok {
|
||||
// Loop to the next node by certain key name.
|
||||
p = v
|
||||
if i == len(array)-1 {
|
||||
if v, ok := p.(map[string]interface{})["*list"]; ok {
|
||||
lists = append(lists, v.(*glist.List))
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok := p.(map[string]interface{})["*fuzz"]; ok {
|
||||
p = p.(map[string]interface{})["*fuzz"]
|
||||
}
|
||||
} else if v, ok := p.(map[string]interface{})["*fuzz"]; ok {
|
||||
// Loop to the next node by fuzzy node item.
|
||||
p = v
|
||||
}
|
||||
// 如果是叶子节点,同时判断当前层级的"*fuzz"键名,解决例如:/user/*action 匹配 /user 的规则
|
||||
if k == len(array)-1 {
|
||||
if _, ok := p.(map[string]interface{})["*fuzz"]; ok {
|
||||
p = p.(map[string]interface{})["*fuzz"]
|
||||
if i == len(array)-1 {
|
||||
// It here also checks the fuzzy item,
|
||||
// for rule case like: "/user/*action" matches to "/user".
|
||||
if v, ok := p.(map[string]interface{})["*fuzz"]; ok {
|
||||
p = v
|
||||
}
|
||||
if _, ok := p.(map[string]interface{})["*list"]; ok {
|
||||
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
|
||||
// The leaf must have a list item. It adds the list to the list array.
|
||||
if v, ok := p.(map[string]interface{})["*list"]; ok {
|
||||
lists = append(lists, v.(*glist.List))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 多层链表遍历检索,从数组末尾的链表开始遍历,末尾的深度高优先级也高
|
||||
// OK, let's loop the result list array, adding the handler item to the result handler result array.
|
||||
// As the tail of the list array has the most priority, it iterates the list array from its tail to head.
|
||||
for i := len(lists) - 1; i >= 0; i-- {
|
||||
for e := lists[i].Front(); e != nil; e = e.Next() {
|
||||
item := e.Value.(*handlerItem)
|
||||
// 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数)
|
||||
// Filter repeated handler item, especially the middleware and hook handlers.
|
||||
// It is necessary, do not remove this checks logic unless you really know how it is necessary.
|
||||
if _, ok := repeatHandlerCheckMap[item.itemId]; ok {
|
||||
continue
|
||||
} else {
|
||||
repeatHandlerCheckMap[item.itemId] = struct{}{}
|
||||
}
|
||||
// 服务路由函数只能添加一次,将重复判断放在这里提高检索效率
|
||||
// Serving handler can only be added to the handler array just once.
|
||||
if hasServe {
|
||||
switch item.itemType {
|
||||
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 动态匹配规则带有gDEFAULT_METHOD的情况,不会像静态规则那样直接解析为所有的HTTP METHOD存储
|
||||
if strings.EqualFold(item.router.Method, gDEFAULT_METHOD) || strings.EqualFold(item.router.Method, method) {
|
||||
// 注意当不带任何动态路由规则时,len(match) == 1
|
||||
if item.router.Method == gDEFAULT_METHOD || item.router.Method == method {
|
||||
// Note the rule having no fuzzy rules: len(match) == 1
|
||||
if match, err := gregex.MatchString(item.router.RegRule, path); err == nil && len(match) > 0 {
|
||||
parsedItem := &handlerParsedItem{item, nil}
|
||||
// 如果需要路由规则中带有URI名称匹配,那么需要重新正则解析URL
|
||||
// If the rule contains fuzzy names,
|
||||
// it needs paring the URL to retrieve the values for the names.
|
||||
if len(item.router.RegNames) > 0 {
|
||||
if len(match) > len(item.router.RegNames) {
|
||||
parsedItem.values = make(map[string]string)
|
||||
// 如果存在存在同名路由参数名称,那么执行覆盖
|
||||
// It there repeated names, it just overwrites the same one.
|
||||
for i, name := range item.router.RegNames {
|
||||
parsedItem.values[name] = match[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
switch item.itemType {
|
||||
// 服务路由函数只能添加一次
|
||||
// The serving handler can be only added just once.
|
||||
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
|
||||
hasServe = true
|
||||
parsedItemList.PushBack(parsedItem)
|
||||
|
||||
// 中间件需要排序在链表中服务函数之前,并且多个中间件按照顺序添加以便于后续执行
|
||||
// The middleware is inserted before the serving handler.
|
||||
// If there're multiple middlewares, they're inserted into the result list by their registering order.
|
||||
// The middlewares are also executed by their registering order.
|
||||
case gHANDLER_TYPE_MIDDLEWARE:
|
||||
if lastMiddlewareElem == nil {
|
||||
lastMiddlewareElem = parsedItemList.PushFront(parsedItem)
|
||||
@ -144,7 +167,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
|
||||
lastMiddlewareElem = parsedItemList.InsertAfter(lastMiddlewareElem, parsedItem)
|
||||
}
|
||||
|
||||
// 钩子函数存在性判断
|
||||
// HOOK handler, just push it back to the list.
|
||||
case gHANDLER_TYPE_HOOK:
|
||||
hasHook = true
|
||||
parsedItemList.PushBack(parsedItem)
|
||||
@ -206,14 +229,3 @@ func (item *handlerItem) MarshalJSON() ([]byte, error) {
|
||||
func (item *handlerParsedItem) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(item.handler)
|
||||
}
|
||||
|
||||
// 生成回调方法查询的Key
|
||||
func (s *Server) serveHandlerKey(method, path, domain string) string {
|
||||
if len(domain) > 0 {
|
||||
domain = "@" + domain
|
||||
}
|
||||
if method == "" {
|
||||
return path + strings.ToLower(domain)
|
||||
}
|
||||
return strings.ToUpper(method) + ":" + path + strings.ToLower(domain)
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func (s *Server) BindObjectRest(pattern string, object interface{}) {
|
||||
}
|
||||
|
||||
func (s *Server) doBindObject(pattern string, object interface{}, method string, middleware []HandlerFunc) {
|
||||
// Convert input method to map for convenience and high performance searching.
|
||||
// Convert input method to map for convenience and high performance searching purpose.
|
||||
var methodMap map[string]bool
|
||||
if len(method) > 0 {
|
||||
methodMap = make(map[string]bool)
|
||||
@ -86,12 +86,16 @@ func (s *Server) doBindObject(pattern string, object interface{}, method string,
|
||||
if !ok {
|
||||
if len(methodMap) > 0 {
|
||||
// 指定的方法名称注册,那么需要使用错误提示
|
||||
s.Logger().Errorf(`invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`,
|
||||
pkgPath, objName, methodName, v.Method(i).Type().String())
|
||||
s.Logger().Errorf(
|
||||
`invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`,
|
||||
pkgPath, objName, methodName, v.Method(i).Type().String(),
|
||||
)
|
||||
} else {
|
||||
// 否则只是Debug提示
|
||||
s.Logger().Debugf(`ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`,
|
||||
pkgPath, objName, methodName, v.Method(i).Type().String())
|
||||
s.Logger().Debugf(
|
||||
`ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`,
|
||||
pkgPath, objName, methodName, v.Method(i).Type().String(),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
43
net/ghttp/ghttp_unit_context_test.go
Normal file
43
net/ghttp/ghttp_unit_context_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package ghttp_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_Context(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.Group("/", func(group *ghttp.RouterGroup) {
|
||||
group.Middleware(func(r *ghttp.Request) {
|
||||
r.Context.Set("traceid", 123)
|
||||
r.Middleware.Next()
|
||||
})
|
||||
group.GET("/", func(r *ghttp.Request) {
|
||||
r.Response.Write(r.Context.Get("traceid"))
|
||||
})
|
||||
})
|
||||
s.SetPort(p)
|
||||
//s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
gtest.Assert(client.GetContent("/"), `123`)
|
||||
})
|
||||
}
|
@ -594,8 +594,9 @@ func MiddlewareCORS(r *ghttp.Request) {
|
||||
func Test_Middleware_CORSAndAuth(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.Use(MiddlewareCORS)
|
||||
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
|
||||
group.Middleware(MiddlewareAuth, MiddlewareCORS)
|
||||
group.Middleware(MiddlewareAuth)
|
||||
group.POST("/user/list", func(r *ghttp.Request) {
|
||||
r.Response.Write("list")
|
||||
})
|
||||
@ -680,3 +681,35 @@ func Test_Middleware_Scope(t *testing.T) {
|
||||
gtest.Assert(client.GetContent("/scope3"), "ae3fb")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Middleware_Panic(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
i := 0
|
||||
s.Group("/", func(group *ghttp.RouterGroup) {
|
||||
group.Group("/", func(group *ghttp.RouterGroup) {
|
||||
group.Middleware(func(r *ghttp.Request) {
|
||||
i++
|
||||
panic("error")
|
||||
r.Middleware.Next()
|
||||
}, func(r *ghttp.Request) {
|
||||
i++
|
||||
r.Middleware.Next()
|
||||
})
|
||||
group.ALL("/", func(r *ghttp.Request) {
|
||||
r.Response.Write(i)
|
||||
})
|
||||
})
|
||||
})
|
||||
s.SetPort(p)
|
||||
//s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
gtest.Assert(client.GetContent("/"), "error")
|
||||
})
|
||||
}
|
76
net/ghttp/ghttp_unit_middleware_cors_test.go
Normal file
76
net/ghttp/ghttp_unit_middleware_cors_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package ghttp_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_Middleware_CORS(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
|
||||
group.Middleware(MiddlewareCORS)
|
||||
group.POST("/user/list", func(r *ghttp.Request) {
|
||||
r.Response.Write("list")
|
||||
})
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
// Common Checks.
|
||||
gtest.Assert(client.GetContent("/"), "Not Found")
|
||||
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
|
||||
|
||||
// GET request does not any route.
|
||||
resp, err := client.Get("/api.v2/user/list")
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
|
||||
resp.Close()
|
||||
|
||||
// POST request matches the route and CORS middleware.
|
||||
resp, err = client.Post("/api.v2/user/list")
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
|
||||
gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With")
|
||||
gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE")
|
||||
gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*")
|
||||
gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800")
|
||||
resp.Close()
|
||||
})
|
||||
// OPTIONS GET
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
client.SetHeader("Access-Control-Request-Method", "GET")
|
||||
resp, err := client.Options("/api.v2/user/list")
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
|
||||
gtest.Assert(resp.ReadAllString(), "Not Found")
|
||||
resp.Close()
|
||||
})
|
||||
// OPTIONS POST
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
client.SetHeader("Access-Control-Request-Method", "POST")
|
||||
resp, err := client.Options("/api.v2/user/list")
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
|
||||
resp.Close()
|
||||
})
|
||||
}
|
176
net/ghttp/ghttp_unit_param_file_test.go
Normal file
176
net/ghttp/ghttp_unit_param_file_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package ghttp_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gogf/gf/debug/gdebug"
|
||||
"github.com/gogf/gf/os/gfile"
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_Params_File_Single(t *testing.T) {
|
||||
dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.BindHandler("/upload/single", func(r *ghttp.Request) {
|
||||
file := r.GetUploadFile("file")
|
||||
if file == nil {
|
||||
r.Response.WriteExit("upload file cannot be empty")
|
||||
}
|
||||
|
||||
if name, err := file.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil {
|
||||
r.Response.WriteExit(name)
|
||||
}
|
||||
r.Response.WriteExit("upload failed")
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// normal name
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
|
||||
dstPath := gfile.Join(dstDirPath, "file1.txt")
|
||||
content := client.PostContent("/upload/single", g.Map{
|
||||
"file": "@file:" + srcPath,
|
||||
})
|
||||
gtest.AssertNE(content, "")
|
||||
gtest.AssertNE(content, "upload file cannot be empty")
|
||||
gtest.AssertNE(content, "upload failed")
|
||||
gtest.Assert(content, "file1.txt")
|
||||
gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath))
|
||||
})
|
||||
// randomly rename.
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt")
|
||||
content := client.PostContent("/upload/single", g.Map{
|
||||
"file": "@file:" + srcPath,
|
||||
"randomlyRename": true,
|
||||
})
|
||||
dstPath := gfile.Join(dstDirPath, content)
|
||||
gtest.AssertNE(content, "")
|
||||
gtest.AssertNE(content, "upload file cannot be empty")
|
||||
gtest.AssertNE(content, "upload failed")
|
||||
gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Params_File_CustomName(t *testing.T) {
|
||||
dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.BindHandler("/upload/single", func(r *ghttp.Request) {
|
||||
file := r.GetUploadFile("file")
|
||||
if file == nil {
|
||||
r.Response.WriteExit("upload file cannot be empty")
|
||||
}
|
||||
file.Filename = "my.txt"
|
||||
if name, err := file.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil {
|
||||
r.Response.WriteExit(name)
|
||||
}
|
||||
r.Response.WriteExit("upload failed")
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
|
||||
dstPath := gfile.Join(dstDirPath, "my.txt")
|
||||
content := client.PostContent("/upload/single", g.Map{
|
||||
"file": "@file:" + srcPath,
|
||||
})
|
||||
gtest.AssertNE(content, "")
|
||||
gtest.AssertNE(content, "upload file cannot be empty")
|
||||
gtest.AssertNE(content, "upload failed")
|
||||
gtest.Assert(content, "my.txt")
|
||||
gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Params_File_Batch(t *testing.T) {
|
||||
dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.BindHandler("/upload/batch", func(r *ghttp.Request) {
|
||||
files := r.GetUploadFiles("file")
|
||||
if files == nil {
|
||||
r.Response.WriteExit("upload file cannot be empty")
|
||||
}
|
||||
if names, err := files.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil {
|
||||
r.Response.WriteExit(gstr.Join(names, ","))
|
||||
}
|
||||
r.Response.WriteExit("upload failed")
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// normal name
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
|
||||
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt")
|
||||
dstPath1 := gfile.Join(dstDirPath, "file1.txt")
|
||||
dstPath2 := gfile.Join(dstDirPath, "file2.txt")
|
||||
content := client.PostContent("/upload/batch", g.Map{
|
||||
"file[0]": "@file:" + srcPath1,
|
||||
"file[1]": "@file:" + srcPath2,
|
||||
})
|
||||
gtest.AssertNE(content, "")
|
||||
gtest.AssertNE(content, "upload file cannot be empty")
|
||||
gtest.AssertNE(content, "upload failed")
|
||||
gtest.Assert(content, "file1.txt,file2.txt")
|
||||
gtest.Assert(gfile.GetContents(dstPath1), gfile.GetContents(srcPath1))
|
||||
gtest.Assert(gfile.GetContents(dstPath2), gfile.GetContents(srcPath2))
|
||||
})
|
||||
// randomly rename.
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
|
||||
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt")
|
||||
content := client.PostContent("/upload/batch", g.Map{
|
||||
"file[0]": "@file:" + srcPath1,
|
||||
"file[1]": "@file:" + srcPath2,
|
||||
"randomlyRename": true,
|
||||
})
|
||||
gtest.AssertNE(content, "")
|
||||
gtest.AssertNE(content, "upload file cannot be empty")
|
||||
gtest.AssertNE(content, "upload failed")
|
||||
|
||||
array := gstr.SplitAndTrim(content, ",")
|
||||
gtest.Assert(len(array), 2)
|
||||
dstPath1 := gfile.Join(dstDirPath, array[0])
|
||||
dstPath2 := gfile.Join(dstDirPath, array[1])
|
||||
gtest.Assert(gfile.GetContents(dstPath1), gfile.GetContents(srcPath1))
|
||||
gtest.Assert(gfile.GetContents(dstPath2), gfile.GetContents(srcPath2))
|
||||
})
|
||||
}
|
48
net/ghttp/ghttp_unit_param_page_test.go
Normal file
48
net/ghttp/ghttp_unit_param_page_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package ghttp_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_Params_Page(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.Group("/", func(group *ghttp.RouterGroup) {
|
||||
group.GET("/list", func(r *ghttp.Request) {
|
||||
page := r.GetPage(5, 2)
|
||||
r.Response.Write(page.GetContent(4))
|
||||
})
|
||||
group.GET("/list/{page}.html", func(r *ghttp.Request) {
|
||||
page := r.GetPage(5, 2)
|
||||
r.Response.Write(page.GetContent(4))
|
||||
})
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
gtest.Assert(client.GetContent("/list"), `<span class="GPageSpan">首页</span><span class="GPageSpan">上一页</span><span class="GPageSpan">1</span><a class="GPageLink" href="/list?page=2" title="2">2</a><a class="GPageLink" href="/list?page=3" title="3">3</a><a class="GPageLink" href="/list?page=2" title="">下一页</a><a class="GPageLink" href="/list?page=3" title="">尾页</a>`)
|
||||
gtest.Assert(client.GetContent("/list?page=3"), `<a class="GPageLink" href="/list?page=1" title="">首页</a><a class="GPageLink" href="/list?page=2" title="">上一页</a><a class="GPageLink" href="/list?page=1" title="1">1</a><a class="GPageLink" href="/list?page=2" title="2">2</a><span class="GPageSpan">3</span><span class="GPageSpan">下一页</span><span class="GPageSpan">尾页</span>`)
|
||||
|
||||
gtest.Assert(client.GetContent("/list/1.html"), `<span class="GPageSpan">首页</span><span class="GPageSpan">上一页</span><span class="GPageSpan">1</span><a class="GPageLink" href="/list/2.html" title="2">2</a><a class="GPageLink" href="/list/3.html" title="3">3</a><a class="GPageLink" href="/list/2.html" title="">下一页</a><a class="GPageLink" href="/list/3.html" title="">尾页</a>`)
|
||||
gtest.Assert(client.GetContent("/list/3.html"), `<a class="GPageLink" href="/list/1.html" title="">首页</a><a class="GPageLink" href="/list/2.html" title="">上一页</a><a class="GPageLink" href="/list/1.html" title="1">1</a><a class="GPageLink" href="/list/2.html" title="2">2</a><span class="GPageSpan">3</span><span class="GPageSpan">下一页</span><span class="GPageSpan">尾页</span>`)
|
||||
})
|
||||
}
|
@ -46,14 +46,6 @@ func (c *ControllerRest) Delete() {
|
||||
c.Response.Write("Controller Delete")
|
||||
}
|
||||
|
||||
func (c *ControllerRest) Patch() {
|
||||
c.Response.Write("Controller Patch")
|
||||
}
|
||||
|
||||
func (c *ControllerRest) Options() {
|
||||
c.Response.Write("Controller Options")
|
||||
}
|
||||
|
||||
func (c *ControllerRest) Head() {
|
||||
c.Response.Header().Set("head-ok", "1")
|
||||
}
|
||||
@ -78,8 +70,6 @@ func Test_Router_ControllerRest(t *testing.T) {
|
||||
gtest.Assert(client.PutContent("/"), "1Controller Put2")
|
||||
gtest.Assert(client.PostContent("/"), "1Controller Post2")
|
||||
gtest.Assert(client.DeleteContent("/"), "1Controller Delete2")
|
||||
gtest.Assert(client.PatchContent("/"), "1Controller Patch2")
|
||||
gtest.Assert(client.OptionsContent("/"), "1Controller Options2")
|
||||
resp1, err := client.Head("/")
|
||||
if err == nil {
|
||||
defer resp1.Close()
|
||||
@ -91,8 +81,6 @@ func Test_Router_ControllerRest(t *testing.T) {
|
||||
gtest.Assert(client.PutContent("/controller-rest/put"), "1Controller Put2")
|
||||
gtest.Assert(client.PostContent("/controller-rest/post"), "1Controller Post2")
|
||||
gtest.Assert(client.DeleteContent("/controller-rest/delete"), "1Controller Delete2")
|
||||
gtest.Assert(client.PatchContent("/controller-rest/patch"), "1Controller Patch2")
|
||||
gtest.Assert(client.OptionsContent("/controller-rest/options"), "1Controller Options2")
|
||||
resp2, err := client.Head("/controller-rest/head")
|
||||
if err == nil {
|
||||
defer resp2.Close()
|
||||
|
@ -70,7 +70,7 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) {
|
||||
r.Response.Write(r.Router.Uri)
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouterMap(false)
|
||||
//s.SetDumpRouterMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
|
||||
|
1
net/ghttp/testdata/upload/file1.txt
vendored
Normal file
1
net/ghttp/testdata/upload/file1.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
file1.txt: This file is for uploading unit test case.
|
1
net/ghttp/testdata/upload/file2.txt
vendored
Normal file
1
net/ghttp/testdata/upload/file2.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
file2.txt: This file is for uploading unit test case.
|
@ -10,6 +10,7 @@ package gfile
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
@ -30,17 +31,32 @@ const (
|
||||
var (
|
||||
// Default perm for file opening.
|
||||
DefaultPerm = os.FileMode(0666)
|
||||
|
||||
// The absolute file path for main package.
|
||||
// It can be only checked and set once.
|
||||
mainPkgPath = gtype.NewString()
|
||||
|
||||
// selfPath is the current running binary path.
|
||||
// As it is most commonly used, it is so defined as an internal package variable.
|
||||
selfPath = ""
|
||||
|
||||
// Temporary directory of system.
|
||||
tempDir = "/tmp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize internal package variable: tempDir.
|
||||
if !Exists(tempDir) {
|
||||
tempDir = os.TempDir()
|
||||
}
|
||||
// Initialize internal package variable: selfPath.
|
||||
selfPath, _ = exec.LookPath(os.Args[0])
|
||||
if selfPath != "" {
|
||||
selfPath, _ = filepath.Abs(selfPath)
|
||||
}
|
||||
if selfPath == "" {
|
||||
selfPath, _ = filepath.Abs(os.Args[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Mkdir creates directories recursively with given <path>.
|
||||
@ -58,22 +74,27 @@ func Mkdir(path string) error {
|
||||
func Create(path string) (*os.File, error) {
|
||||
dir := Dir(path)
|
||||
if !Exists(dir) {
|
||||
Mkdir(dir)
|
||||
if err := Mkdir(dir); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return os.Create(path)
|
||||
}
|
||||
|
||||
// Open opens file/directory readonly.
|
||||
// Open opens file/directory READONLY.
|
||||
func Open(path string) (*os.File, error) {
|
||||
return os.Open(path)
|
||||
}
|
||||
|
||||
// OpenFile opens file/directory with given <flag> and <perm>.
|
||||
// OpenFile opens file/directory with custom <flag> and <perm>.
|
||||
// The parameter <flag> is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc.
|
||||
func OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
return os.OpenFile(path, flag, perm)
|
||||
}
|
||||
|
||||
// OpenWithFlag opens file/directory with default perm and given <flag>.
|
||||
// OpenWithFlag opens file/directory with default perm and custom <flag>.
|
||||
// The default <perm> is 0666.
|
||||
// The parameter <flag> is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc.
|
||||
func OpenWithFlag(path string, flag int) (*os.File, error) {
|
||||
f, err := os.OpenFile(path, flag, DefaultPerm)
|
||||
if err != nil {
|
||||
@ -82,9 +103,11 @@ func OpenWithFlag(path string, flag int) (*os.File, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// OpenWithFlagPerm opens file/directory with given <flag> and <perm>.
|
||||
// OpenWithFlagPerm opens file/directory with custom <flag> and <perm>.
|
||||
// The parameter <flag> is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc.
|
||||
// The parameter <perm> is like: 0600, 0666, 0777, etc.
|
||||
func OpenWithFlagPerm(path string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
f, err := os.OpenFile(path, flag, os.FileMode(perm))
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -93,7 +116,14 @@ func OpenWithFlagPerm(path string, flag int, perm os.FileMode) (*os.File, error)
|
||||
|
||||
// Join joins string array paths with file separator of current system.
|
||||
func Join(paths ...string) string {
|
||||
return strings.Join(paths, Separator)
|
||||
var s string
|
||||
for _, path := range paths {
|
||||
if s != "" {
|
||||
s += Separator
|
||||
}
|
||||
s += gstr.TrimRight(path, Separator)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Exists checks whether given <path> exist.
|
||||
@ -198,6 +228,7 @@ func Glob(pattern string, onlyNames ...bool) ([]string, error) {
|
||||
// Remove deletes all file/directory with <path> parameter.
|
||||
// If parameter <path> is directory, it deletes it recursively.
|
||||
func Remove(path string) error {
|
||||
//intlog.Print(`Remove:`, path)
|
||||
return os.RemoveAll(path)
|
||||
}
|
||||
|
||||
@ -268,17 +299,7 @@ func RealPath(path string) string {
|
||||
|
||||
// SelfPath returns absolute file path of current running process(binary).
|
||||
func SelfPath() string {
|
||||
path, _ := exec.LookPath(os.Args[0])
|
||||
if path != "" {
|
||||
path, _ = filepath.Abs(path)
|
||||
if path != "" {
|
||||
return path
|
||||
}
|
||||
}
|
||||
if path == "" {
|
||||
path, _ = filepath.Abs(os.Args[0])
|
||||
}
|
||||
return path
|
||||
return selfPath
|
||||
}
|
||||
|
||||
// SelfName returns file name of current running process(binary).
|
||||
|
@ -380,9 +380,9 @@ func ParseTimeFromContent(content string, format ...string) *Time {
|
||||
|
||||
// FuncCost calculates the cost time of function <f> in nanoseconds.
|
||||
func FuncCost(f func()) int64 {
|
||||
t := Nanosecond()
|
||||
t := TimestampNano()
|
||||
f()
|
||||
return Nanosecond() - t
|
||||
return TimestampNano() - t
|
||||
}
|
||||
|
||||
// isNumeric checks whether given <s> is a number.
|
||||
|
@ -157,6 +157,7 @@ func (t *Time) Nanosecond() int {
|
||||
return t.Time.Nanosecond()
|
||||
}
|
||||
|
||||
// String returns current time object as string.
|
||||
func (t *Time) String() string {
|
||||
if t == nil {
|
||||
return ""
|
||||
@ -167,14 +168,6 @@ func (t *Time) String() string {
|
||||
return t.Format("Y-m-d H:i:s")
|
||||
}
|
||||
|
||||
// String returns current time object as string.
|
||||
//func (t Time) String() string {
|
||||
// if t.IsZero() {
|
||||
// return ""
|
||||
// }
|
||||
// return t.Format("Y-m-d H:i:s")
|
||||
//}
|
||||
|
||||
// Clone returns a new Time object which is a clone of current time object.
|
||||
func (t *Time) Clone() *Time {
|
||||
return New(t.Time)
|
||||
|
@ -4,16 +4,18 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// Package gtimer implements Hierarchical Timing Wheel for interval/delayed jobs running and management.
|
||||
// Package gtimer implements Hierarchical Timing Wheel for interval/delayed jobs
|
||||
// running and management.
|
||||
//
|
||||
// This package is designed for management for millions of timing jobs.
|
||||
// The differences between gtime and gcron are as follows:
|
||||
// 1. gcron is implemented based on gtimer.
|
||||
// This package is designed for management for millions of timing jobs. The differences
|
||||
// between gtimer and gcron are as follows:
|
||||
// 1. package gcron is implemented based on package gtimer.
|
||||
// 2. gtimer is designed for high performance and for millions of timing jobs.
|
||||
// 3. gcron supports pattern grammar like linux crontab.
|
||||
// 4. gtimer's benchmark OP is measured in nanoseconds, and gcron's benchmark OP is measured in microseconds.
|
||||
// 3. gcron supports configuration pattern grammar like linux crontab, which is more manually readable.
|
||||
// 4. gtimer's benchmark OP is measured in nanoseconds, and gcron's benchmark OP is measured
|
||||
// in microseconds.
|
||||
//
|
||||
// Note the common delay of the timer: https://github.com/golang/go/issues/14410
|
||||
// ALSO VERY NOTE the common delay of the timer: https://github.com/golang/go/issues/14410
|
||||
package gtimer
|
||||
|
||||
import (
|
||||
@ -119,8 +121,10 @@ func DelayAddTimes(delay time.Duration, interval time.Duration, times int, job J
|
||||
defaultTimer.DelayAddTimes(delay, interval, times, job)
|
||||
}
|
||||
|
||||
// Exit is used in timing job, which exits and marks it closed from timer.
|
||||
// The timing job will be removed from timer later.
|
||||
// Exit is used in timing job internally, which exits and marks it closed from timer.
|
||||
// The timing job will be automatically removed from timer later. It uses "panic-recover"
|
||||
// mechanism internally implementing this feature, which is designed for simplification
|
||||
// and convenience.
|
||||
func Exit() {
|
||||
panic(gPANIC_EXIT)
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
var (
|
||||
regexMu = sync.RWMutex{}
|
||||
// Cache for regex object.
|
||||
// TODO There's no expiring logic for this map.
|
||||
// Note that there's no expiring logic for this map.
|
||||
regexMap = make(map[string]*regexp.Regexp)
|
||||
)
|
||||
|
||||
@ -22,29 +22,25 @@ var (
|
||||
// It uses cache to enhance the performance for compiling regular expression pattern,
|
||||
// which means, it will return the same *regexp.Regexp object with the same regular
|
||||
// expression pattern.
|
||||
func getRegexp(pattern string) (*regexp.Regexp, error) {
|
||||
if r := getCache(pattern); r != nil {
|
||||
return r, nil
|
||||
}
|
||||
if r, err := regexp.Compile(pattern); err == nil {
|
||||
setCache(pattern, r)
|
||||
return r, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// getCache returns *regexp.Regexp object from cache by given <pattern>, for internal usage.
|
||||
func getCache(pattern string) (regex *regexp.Regexp) {
|
||||
//
|
||||
// It is concurrent-safe for multiple goroutines.
|
||||
func getRegexp(pattern string) (regex *regexp.Regexp, err error) {
|
||||
// Retrieve the regular expression object using reading lock.
|
||||
regexMu.RLock()
|
||||
regex = regexMap[pattern]
|
||||
regexMu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// setCache stores *regexp.Regexp object into cache, for internal usage.
|
||||
func setCache(pattern string, regex *regexp.Regexp) {
|
||||
if regex != nil {
|
||||
return
|
||||
}
|
||||
// If it does not exist in the cache,
|
||||
// it compiles the pattern and creates one.
|
||||
regex, err = regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Cache the result object using writing lock.
|
||||
regexMu.Lock()
|
||||
regexMap[pattern] = regex
|
||||
regexMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -465,16 +465,16 @@ func SplitAndTrimSpace(str, delimiter string) []string {
|
||||
return array
|
||||
}
|
||||
|
||||
// Join concatenates the elements of a to create a single string. The separator string
|
||||
// sep is placed between elements in the resulting string.
|
||||
// Join concatenates the elements of <array> to create a single string. The separator string
|
||||
// <sep> is placed between elements in the resulting string.
|
||||
func Join(array []string, sep string) string {
|
||||
return strings.Join(array, sep)
|
||||
}
|
||||
|
||||
// JoinAny concatenates the elements of a to create a single string. The separator string
|
||||
// sep is placed between elements in the resulting string.
|
||||
// JoinAny concatenates the elements of <array> to create a single string. The separator string
|
||||
// <sep> is placed between elements in the resulting string.
|
||||
//
|
||||
// The parameter <array> can be any type of slice.
|
||||
// The parameter <array> can be any type of slice, which be converted to string array.
|
||||
func JoinAny(array interface{}, sep string) string {
|
||||
return strings.Join(gconv.Strings(array), sep)
|
||||
}
|
||||
|
@ -48,7 +48,8 @@ var (
|
||||
)
|
||||
|
||||
// Convert converts the variable <i> to the type <t>, the type <t> is specified by string.
|
||||
// The optional parameter <params> is used for additional parameter passing.
|
||||
// The optional parameter <params> is used for additional necessary parameter for this conversion.
|
||||
// It supports common types conversion as its conversion based on type name string.
|
||||
func Convert(i interface{}, t string, params ...interface{}) interface{} {
|
||||
switch t {
|
||||
case "int":
|
||||
@ -121,6 +122,7 @@ func Convert(i interface{}, t string, params ...interface{}) interface{} {
|
||||
case "Duration", "time.Duration":
|
||||
return Duration(i)
|
||||
default:
|
||||
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
@ -40,6 +40,7 @@ func MapDeep(value interface{}, tags ...string) map[string]interface{} {
|
||||
}
|
||||
|
||||
// doMapConvert implements the map converting.
|
||||
// It automatically checks and converts json string to map if <value> is string/[]byte.
|
||||
func doMapConvert(value interface{}, recursive bool, tags ...string) map[string]interface{} {
|
||||
if value == nil {
|
||||
return nil
|
||||
@ -122,7 +123,7 @@ func doMapConvert(value interface{}, recursive bool, tags ...string) map[string]
|
||||
for k, v := range r {
|
||||
m[String(k)] = v
|
||||
}
|
||||
// Not a common type, then use reflection.
|
||||
// Not a common type, it then uses reflection for conversion.
|
||||
default:
|
||||
rv := reflect.ValueOf(value)
|
||||
kind := rv.Kind()
|
||||
@ -132,6 +133,20 @@ func doMapConvert(value interface{}, recursive bool, tags ...string) map[string]
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
// If <value> is type of array, it converts the value of even number index as its key and
|
||||
// the value of odd number index as its corresponding value.
|
||||
// Eg:
|
||||
// []string{"k1","v1","k2","v2"} => map[string]interface{}{"k1":"v1", "k2":"v2"}
|
||||
// []string{"k1","v1","k2"} => map[string]interface{}{"k1":"v1", "k2":nil}
|
||||
case reflect.Slice, reflect.Array:
|
||||
length := rv.Len()
|
||||
for i := 0; i < length; i += 2 {
|
||||
if i+1 < length {
|
||||
m[String(rv.Index(i).Interface())] = rv.Index(i + 1).Interface()
|
||||
} else {
|
||||
m[String(rv.Index(i).Interface())] = nil
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
ks := rv.MapKeys()
|
||||
for _, k := range ks {
|
||||
|
@ -242,6 +242,7 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
|
||||
if !structFieldValue.CanSet() {
|
||||
return nil
|
||||
}
|
||||
// If any panic, it secondly uses reflect conversion and assignment.
|
||||
defer func() {
|
||||
if recover() != nil {
|
||||
err = bindVarToReflectValue(structFieldValue, value)
|
||||
@ -250,6 +251,7 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
|
||||
if empty.IsNil(value) {
|
||||
structFieldValue.Set(reflect.Zero(structFieldValue.Type()))
|
||||
} else {
|
||||
// It firstly simply assigns the value to the attribute.
|
||||
structFieldValue.Set(reflect.ValueOf(Convert(value, structFieldValue.Type().String())))
|
||||
}
|
||||
return nil
|
||||
@ -260,7 +262,8 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
|
||||
switch structFieldValue.Kind() {
|
||||
case reflect.Struct:
|
||||
if err := Struct(value, structFieldValue); err != nil {
|
||||
structFieldValue.Set(reflect.ValueOf(value))
|
||||
// Note there's reflect conversion mechanism here.
|
||||
structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type()))
|
||||
}
|
||||
// Note that the slice element might be type of struct,
|
||||
// so it uses Struct function doing the converting internally.
|
||||
@ -275,13 +278,15 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
|
||||
if t.Kind() == reflect.Ptr {
|
||||
e := reflect.New(t.Elem()).Elem()
|
||||
if err := Struct(v.Index(i).Interface(), e); err != nil {
|
||||
e.Set(reflect.ValueOf(v.Index(i).Interface()))
|
||||
// Note there's reflect conversion mechanism here.
|
||||
e.Set(reflect.ValueOf(v.Index(i).Interface()).Convert(t))
|
||||
}
|
||||
a.Index(i).Set(e.Addr())
|
||||
} else {
|
||||
e := reflect.New(t).Elem()
|
||||
if err := Struct(v.Index(i).Interface(), e); err != nil {
|
||||
e.Set(reflect.ValueOf(v.Index(i).Interface()))
|
||||
// Note there's reflect conversion mechanism here.
|
||||
e.Set(reflect.ValueOf(v.Index(i).Interface()).Convert(t))
|
||||
}
|
||||
a.Index(i).Set(e)
|
||||
}
|
||||
@ -293,13 +298,15 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
|
||||
if t.Kind() == reflect.Ptr {
|
||||
e := reflect.New(t.Elem()).Elem()
|
||||
if err := Struct(value, e); err != nil {
|
||||
e.Set(reflect.ValueOf(value))
|
||||
// Note there's reflect conversion mechanism here.
|
||||
e.Set(reflect.ValueOf(value).Convert(t))
|
||||
}
|
||||
a.Index(0).Set(e.Addr())
|
||||
} else {
|
||||
e := reflect.New(t).Elem()
|
||||
if err := Struct(value, e); err != nil {
|
||||
e.Set(reflect.ValueOf(value))
|
||||
// Note there's reflect conversion mechanism here.
|
||||
e.Set(reflect.ValueOf(value).Convert(t))
|
||||
}
|
||||
a.Index(0).Set(e)
|
||||
}
|
||||
@ -311,34 +318,40 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
|
||||
// Assign value with interface Set.
|
||||
// Note that only pointer can implement interface Set.
|
||||
if v, ok := item.Interface().(apiUnmarshalValue); ok {
|
||||
v.UnmarshalValue(value)
|
||||
err = v.UnmarshalValue(value)
|
||||
structFieldValue.Set(item)
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
elem := item.Elem()
|
||||
if err = bindVarToReflectValue(elem, value); err == nil {
|
||||
structFieldValue.Set(elem.Addr())
|
||||
}
|
||||
|
||||
// It mainly and specially handles the interface of nil value.
|
||||
case reflect.Interface:
|
||||
if value == nil {
|
||||
// Specially.
|
||||
structFieldValue.Set(reflect.ValueOf((*interface{})(nil)))
|
||||
} else {
|
||||
structFieldValue.Set(reflect.ValueOf(value))
|
||||
// Note there's reflect conversion mechanism here.
|
||||
structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type()))
|
||||
}
|
||||
|
||||
default:
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = errors.New(
|
||||
fmt.Sprintf(`cannot convert "%d" to type "%s"`,
|
||||
fmt.Sprintf(`cannot convert value "%d" to type "%s"`,
|
||||
value,
|
||||
structFieldValue.Type().String(),
|
||||
),
|
||||
)
|
||||
}
|
||||
}()
|
||||
structFieldValue.Set(reflect.ValueOf(value))
|
||||
// It here uses reflect converting <value> to type of the attribute and assigns
|
||||
// the result value to the attribute. It might fail and panic if the usual Go
|
||||
// conversion rules do not allow conversion.
|
||||
structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -37,6 +37,23 @@ func Test_Map_Basic(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Map_Slice(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
slice1 := g.Slice{"1", "2", "3", "4"}
|
||||
slice2 := g.Slice{"1", "2", "3"}
|
||||
slice3 := g.Slice{}
|
||||
gtest.Assert(gconv.Map(slice1), g.Map{
|
||||
"1": "2",
|
||||
"3": "4",
|
||||
})
|
||||
gtest.Assert(gconv.Map(slice2), g.Map{
|
||||
"1": "2",
|
||||
"3": nil,
|
||||
})
|
||||
gtest.Assert(gconv.Map(slice3), g.Map{})
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Map_StructWithGconvTag(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
type User struct {
|
||||
|
@ -324,6 +324,36 @@ func Test_Struct_Attr_Struct_Slice_Ptr(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Struct_Attr_CustomType1(t *testing.T) {
|
||||
type MyInt int
|
||||
type User struct {
|
||||
Id MyInt
|
||||
Name string
|
||||
}
|
||||
gtest.Case(t, func() {
|
||||
user := new(User)
|
||||
err := gconv.Struct(g.Map{"id": 1, "name": "john"}, user)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(user.Id, 1)
|
||||
gtest.Assert(user.Name, "john")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Struct_Attr_CustomType2(t *testing.T) {
|
||||
type MyInt int
|
||||
type User struct {
|
||||
Id []MyInt
|
||||
Name string
|
||||
}
|
||||
gtest.Case(t, func() {
|
||||
user := new(User)
|
||||
err := gconv.Struct(g.Map{"id": g.Slice{1, 2}, "name": "john"}, user)
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(user.Id, g.Slice{1, 2})
|
||||
gtest.Assert(user.Name, "john")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Struct_PrivateAttribute(t *testing.T) {
|
||||
type User struct {
|
||||
Id int
|
||||
|
@ -9,292 +9,217 @@ package gpage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
url2 "net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
"github.com/gogf/gf/text/gregex"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
"github.com/gogf/gf/util/gconv"
|
||||
"math"
|
||||
)
|
||||
|
||||
// 分页对象
|
||||
// Page is the pagination implementer.
|
||||
// All the attributes are public, you can change them when necessary.
|
||||
type Page struct {
|
||||
Url *url2.URL // 当前页面的URL对象
|
||||
Router *ghttp.Router // 当前页面的路由对象(与gf框架耦合,在静态分页下有效)
|
||||
UrlTemplate string // URL生成规则,内部可使用{.page}变量指定页码
|
||||
TotalSize int // 总共数据条数
|
||||
TotalPage int // 总页数
|
||||
CurrentPage int // 当前页码
|
||||
PageName string // 分页参数名称(GET参数)
|
||||
NextPageTag string // 下一页标签
|
||||
PrevPageTag string // 上一页标签
|
||||
FirstPageTag string // 首页标签
|
||||
LastPageTag string // 尾页标签
|
||||
PrevBar string // 上一分页条
|
||||
NextBar string // 下一分页条
|
||||
PageBarNum int // 控制分页条的数量
|
||||
AjaxActionName string // AJAX方法名,当该属性有值时,表示使用AJAX分页
|
||||
TotalSize int // Total size.
|
||||
TotalPage int // Total page, which is automatically calculated.
|
||||
CurrentPage int // Current page number >= 1.
|
||||
UrlTemplate string // Custom url template for page url producing.
|
||||
LinkStyle string // CSS style name for HTML link tag <a>.
|
||||
SpanStyle string // CSS style name for HTML span tag <span>, which is used for first, current and last page tag.
|
||||
SelectStyle string // CSS style name for HTML select tag <select>.
|
||||
NextPageTag string // Tag name for next p.
|
||||
PrevPageTag string // Tag name for prev p.
|
||||
FirstPageTag string // Tag name for first p.
|
||||
LastPageTag string // Tag name for last p.
|
||||
PrevBarTag string // Tag string for prev bar.
|
||||
NextBarTag string // Tag string for next bar.
|
||||
PageBarNum int // Page bar number for displaying.
|
||||
AjaxActionName string // Ajax function name. Ajax is enabled if this attribute is not empty.
|
||||
}
|
||||
|
||||
// 创建一个分页对象,输入参数分别为:
|
||||
// 总数量、每页数量、当前页码、当前的URL(URI+QUERY)、(可选)路由规则(例如: /user/list/:page、/order/list/*page、/order/list/{page}.html)
|
||||
func New(TotalSize, perPage int, CurrentPage interface{}, url string, router ...*ghttp.Router) *Page {
|
||||
u, _ := url2.Parse(url)
|
||||
page := &Page{
|
||||
PageName: "page",
|
||||
const (
|
||||
PAGE_NAME = "page" // PAGE_NAME defines the default page name.
|
||||
PAGE_PLACE_HOLDER = "{.page}" // PAGE_PLACE_HOLDER defines the place holder for the url template.
|
||||
)
|
||||
|
||||
// New creates and returns a pagination manager.
|
||||
// Note that the parameter <urlTemplate> specifies the URL producing template, like:
|
||||
// /user/list/{.page}, /user/list/{.page}.html, /user/list?page={.page}&type=1, etc.
|
||||
// The build-in variable in <urlTemplate> "{.page}" specifies the page number, which will be replaced by certain
|
||||
// page number when producing.
|
||||
func New(totalSize, pageSize, currentPage int, urlTemplate string) *Page {
|
||||
p := &Page{
|
||||
LinkStyle: "GPageLink",
|
||||
SpanStyle: "GPageSpan",
|
||||
SelectStyle: "GPageSelect",
|
||||
PrevPageTag: "<",
|
||||
NextPageTag: ">",
|
||||
FirstPageTag: "|<",
|
||||
LastPageTag: ">|",
|
||||
PrevBar: "<<",
|
||||
NextBar: ">>",
|
||||
TotalSize: TotalSize,
|
||||
TotalPage: int(math.Ceil(float64(TotalSize) / float64(perPage))),
|
||||
CurrentPage: 1,
|
||||
PrevBarTag: "<<",
|
||||
NextBarTag: ">>",
|
||||
TotalSize: totalSize,
|
||||
TotalPage: int(math.Ceil(float64(totalSize) / float64(pageSize))),
|
||||
CurrentPage: currentPage,
|
||||
PageBarNum: 10,
|
||||
Url: u,
|
||||
UrlTemplate: urlTemplate,
|
||||
}
|
||||
curPage := gconv.Int(CurrentPage)
|
||||
if curPage > 0 {
|
||||
page.CurrentPage = curPage
|
||||
if currentPage == 0 {
|
||||
p.CurrentPage = 1
|
||||
}
|
||||
if len(router) > 0 {
|
||||
page.Router = router[0]
|
||||
}
|
||||
return page
|
||||
return p
|
||||
}
|
||||
|
||||
// 启用AJAX分页
|
||||
func (page *Page) EnableAjax(actionName string) {
|
||||
page.AjaxActionName = actionName
|
||||
// NextPage returns the HTML content for the next page.
|
||||
func (p *Page) NextPage() string {
|
||||
if p.CurrentPage < p.TotalPage {
|
||||
return p.GetLink(p.CurrentPage+1, p.NextPageTag, "")
|
||||
}
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.NextPageTag)
|
||||
}
|
||||
|
||||
// 设置URL生成规则模板,模板中可使用{.page}变量指定页码位置
|
||||
func (page *Page) SetUrlTemplate(template string) {
|
||||
page.UrlTemplate = template
|
||||
// PrevPage returns the HTML content for the previous page.
|
||||
func (p *Page) PrevPage() string {
|
||||
if p.CurrentPage > 1 {
|
||||
return p.GetLink(p.CurrentPage-1, p.PrevPageTag, "")
|
||||
}
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.PrevPageTag)
|
||||
}
|
||||
|
||||
// 获取显示"下一页"的内容.
|
||||
func (page *Page) NextPage(styles ...string) string {
|
||||
var curStyle, style string
|
||||
if len(styles) > 0 {
|
||||
curStyle = styles[0]
|
||||
// FirstPage returns the HTML content for the first page.
|
||||
func (p *Page) FirstPage() string {
|
||||
if p.CurrentPage == 1 {
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.FirstPageTag)
|
||||
}
|
||||
if len(styles) > 1 {
|
||||
style = styles[0]
|
||||
}
|
||||
if page.CurrentPage < page.TotalPage {
|
||||
return page.GetLink(page.GetUrl(page.CurrentPage+1), page.NextPageTag, "下一页", style)
|
||||
}
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.NextPageTag)
|
||||
return p.GetLink(1, p.FirstPageTag, "")
|
||||
}
|
||||
|
||||
// 获取显示“上一页”的内容
|
||||
func (page *Page) PrevPage(styles ...string) string {
|
||||
var curStyle, style string
|
||||
if len(styles) > 0 {
|
||||
curStyle = styles[0]
|
||||
// LastPage returns the HTML content for the last page.
|
||||
func (p *Page) LastPage() string {
|
||||
if p.CurrentPage == p.TotalPage {
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.LastPageTag)
|
||||
}
|
||||
if len(styles) > 1 {
|
||||
style = styles[0]
|
||||
}
|
||||
if page.CurrentPage > 1 {
|
||||
return page.GetLink(page.GetUrl(page.CurrentPage-1), page.PrevPageTag, "上一页", style)
|
||||
}
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.PrevPageTag)
|
||||
return p.GetLink(p.TotalPage, p.LastPageTag, "")
|
||||
}
|
||||
|
||||
// 获取显示“首页”的代码
|
||||
func (page *Page) FirstPage(styles ...string) string {
|
||||
var curStyle, style string
|
||||
if len(styles) > 0 {
|
||||
curStyle = styles[0]
|
||||
// PageBar returns the HTML page bar content with link and span tags.
|
||||
func (p *Page) PageBar() string {
|
||||
plus := int(math.Ceil(float64(p.PageBarNum / 2)))
|
||||
if p.PageBarNum-plus+p.CurrentPage > p.TotalPage {
|
||||
plus = p.PageBarNum - p.TotalPage + p.CurrentPage
|
||||
}
|
||||
if len(styles) > 1 {
|
||||
style = styles[0]
|
||||
}
|
||||
if page.CurrentPage == 1 {
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.FirstPageTag)
|
||||
}
|
||||
return page.GetLink(page.GetUrl(1), page.FirstPageTag, "第一页", style)
|
||||
}
|
||||
|
||||
// 获取显示“尾页”的内容
|
||||
func (page *Page) LastPage(styles ...string) string {
|
||||
var curStyle, style string
|
||||
if len(styles) > 0 {
|
||||
curStyle = styles[0]
|
||||
}
|
||||
if len(styles) > 1 {
|
||||
style = styles[0]
|
||||
}
|
||||
if page.CurrentPage == page.TotalPage {
|
||||
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.LastPageTag)
|
||||
}
|
||||
return page.GetLink(page.GetUrl(page.TotalPage), page.LastPageTag, "最后页", style)
|
||||
}
|
||||
|
||||
// 获得分页条列表内容
|
||||
func (page *Page) PageBar(styles ...string) string {
|
||||
var curStyle, style string
|
||||
if len(styles) > 0 {
|
||||
curStyle = styles[0]
|
||||
}
|
||||
if len(styles) > 1 {
|
||||
style = styles[0]
|
||||
}
|
||||
plus := int(math.Ceil(float64(page.PageBarNum / 2)))
|
||||
if page.PageBarNum-plus+page.CurrentPage > page.TotalPage {
|
||||
plus = page.PageBarNum - page.TotalPage + page.CurrentPage
|
||||
}
|
||||
begin := page.CurrentPage - plus + 1
|
||||
begin := p.CurrentPage - plus + 1
|
||||
if begin < 1 {
|
||||
begin = 1
|
||||
}
|
||||
ret := ""
|
||||
for i := begin; i < begin+page.PageBarNum; i++ {
|
||||
if i <= page.TotalPage {
|
||||
if i != page.CurrentPage {
|
||||
ret += page.GetLink(page.GetUrl(i), gconv.String(i), style, "")
|
||||
barContent := ""
|
||||
for i := begin; i < begin+p.PageBarNum; i++ {
|
||||
if i <= p.TotalPage {
|
||||
if i != p.CurrentPage {
|
||||
barText := gconv.String(i)
|
||||
barContent += p.GetLink(i, barText, barText)
|
||||
} else {
|
||||
ret += fmt.Sprintf(`<span class="%s">%d</span>`, curStyle, i)
|
||||
barContent += fmt.Sprintf(`<span class="%s">%d</span>`, p.SpanStyle, i)
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ret
|
||||
return barContent
|
||||
}
|
||||
|
||||
// 获取基于select标签的显示跳转按钮的代码
|
||||
func (page *Page) SelectBar() string {
|
||||
ret := `<select name="gpage_select" onchange="window.location.href=this.value">`
|
||||
for i := 1; i <= page.TotalPage; i++ {
|
||||
if i == page.CurrentPage {
|
||||
ret += fmt.Sprintf(`<option value="%s" selected>%d</option>`, page.GetUrl(i), i)
|
||||
// SelectBar returns the select HTML content for pagination.
|
||||
func (p *Page) SelectBar() string {
|
||||
barContent := fmt.Sprintf(`<select name="%s" onchange="window.location.href=this.value">`, p.SelectStyle)
|
||||
for i := 1; i <= p.TotalPage; i++ {
|
||||
if i == p.CurrentPage {
|
||||
barContent += fmt.Sprintf(`<option value="%s" selected>%d</option>`, p.GetUrl(i), i)
|
||||
} else {
|
||||
ret += fmt.Sprintf(`<option value="%s">%d</option>`, page.GetUrl(i), i)
|
||||
barContent += fmt.Sprintf(`<option value="%s">%d</option>`, p.GetUrl(i), i)
|
||||
}
|
||||
}
|
||||
ret += "</select>"
|
||||
return ret
|
||||
barContent += "</select>"
|
||||
return barContent
|
||||
}
|
||||
|
||||
// 预定义的分页显示风格内容
|
||||
func (page *Page) GetContent(mode int) string {
|
||||
// GetContent returns the page content for predefined mode.
|
||||
// These predefined contents are mainly for chinese localization purpose. You can defines your own
|
||||
// page function retrieving the page content according to the implementation of this function.
|
||||
func (p *Page) GetContent(mode int) string {
|
||||
switch mode {
|
||||
case 1:
|
||||
page.NextPageTag = "下一页"
|
||||
page.PrevPageTag = "上一页"
|
||||
p.NextPageTag = "下一页"
|
||||
p.PrevPageTag = "上一页"
|
||||
return fmt.Sprintf(
|
||||
`%s <span class="current">%d</span> %s`,
|
||||
page.PrevPage(),
|
||||
page.CurrentPage,
|
||||
page.NextPage(),
|
||||
p.PrevPage(),
|
||||
p.CurrentPage,
|
||||
p.NextPage(),
|
||||
)
|
||||
|
||||
case 2:
|
||||
page.NextPageTag = "下一页>>"
|
||||
page.PrevPageTag = "<<上一页"
|
||||
page.FirstPageTag = "首页"
|
||||
page.LastPageTag = "尾页"
|
||||
p.NextPageTag = "下一页>>"
|
||||
p.PrevPageTag = "<<上一页"
|
||||
p.FirstPageTag = "首页"
|
||||
p.LastPageTag = "尾页"
|
||||
return fmt.Sprintf(
|
||||
`%s%s<span class="current">[第%d页]</span>%s%s第%s页`,
|
||||
page.FirstPage(),
|
||||
page.PrevPage(),
|
||||
page.CurrentPage,
|
||||
page.NextPage(),
|
||||
page.LastPage(),
|
||||
page.SelectBar(),
|
||||
p.FirstPage(),
|
||||
p.PrevPage(),
|
||||
p.CurrentPage,
|
||||
p.NextPage(),
|
||||
p.LastPage(),
|
||||
p.SelectBar(),
|
||||
)
|
||||
|
||||
case 3:
|
||||
page.NextPageTag = "下一页"
|
||||
page.PrevPageTag = "上一页"
|
||||
page.FirstPageTag = "首页"
|
||||
page.LastPageTag = "尾页"
|
||||
pageStr := page.FirstPage()
|
||||
pageStr += page.PrevPage()
|
||||
pageStr += page.PageBar("current")
|
||||
pageStr += page.NextPage()
|
||||
pageStr += page.LastPage()
|
||||
p.NextPageTag = "下一页"
|
||||
p.PrevPageTag = "上一页"
|
||||
p.FirstPageTag = "首页"
|
||||
p.LastPageTag = "尾页"
|
||||
pageStr := p.FirstPage()
|
||||
pageStr += p.PrevPage()
|
||||
pageStr += p.PageBar()
|
||||
pageStr += p.NextPage()
|
||||
pageStr += p.LastPage()
|
||||
pageStr += fmt.Sprintf(
|
||||
`<span>当前页%d/%d</span> <span>共%d条</span>`,
|
||||
page.CurrentPage,
|
||||
page.TotalPage,
|
||||
page.TotalSize,
|
||||
p.CurrentPage,
|
||||
p.TotalPage,
|
||||
p.TotalSize,
|
||||
)
|
||||
return pageStr
|
||||
|
||||
case 4:
|
||||
page.NextPageTag = "下一页"
|
||||
page.PrevPageTag = "上一页"
|
||||
page.FirstPageTag = "首页"
|
||||
page.LastPageTag = "尾页"
|
||||
pageStr := page.FirstPage()
|
||||
pageStr += page.PrevPage()
|
||||
pageStr += page.PageBar("current")
|
||||
pageStr += page.NextPage()
|
||||
pageStr += page.LastPage()
|
||||
p.NextPageTag = "下一页"
|
||||
p.PrevPageTag = "上一页"
|
||||
p.FirstPageTag = "首页"
|
||||
p.LastPageTag = "尾页"
|
||||
pageStr := p.FirstPage()
|
||||
pageStr += p.PrevPage()
|
||||
pageStr += p.PageBar()
|
||||
pageStr += p.NextPage()
|
||||
pageStr += p.LastPage()
|
||||
return pageStr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 为指定的页面返回地址值
|
||||
func (page *Page) GetUrl(pageNo int) string {
|
||||
// 复制一个URL对象
|
||||
url := *page.Url
|
||||
if len(page.UrlTemplate) == 0 && page.Router != nil {
|
||||
page.UrlTemplate = page.makeUrlTemplate(url.Path, page.Router)
|
||||
}
|
||||
if len(page.UrlTemplate) > 0 {
|
||||
// 指定URL生成模板
|
||||
url.Path = gstr.Replace(page.UrlTemplate, "{.page}", gconv.String(pageNo))
|
||||
return url.String()
|
||||
}
|
||||
|
||||
values := page.Url.Query()
|
||||
values.Set(page.PageName, gconv.String(pageNo))
|
||||
url.RawQuery = values.Encode()
|
||||
return url.String()
|
||||
// GetUrl parses the UrlTemplate with given page number and returns the URL string.
|
||||
// Note that the UrlTemplate attribute can be either an URL or a URI string with "{.page}"
|
||||
// place holder specifying the page number position.
|
||||
func (p *Page) GetUrl(page int) string {
|
||||
return gstr.Replace(p.UrlTemplate, PAGE_PLACE_HOLDER, gconv.String(page))
|
||||
}
|
||||
|
||||
// 根据当前URL以及注册路由信息计算出对应的URL模板
|
||||
func (page *Page) makeUrlTemplate(url string, router *ghttp.Router) (tpl string) {
|
||||
if page.Router != nil && len(router.RegNames) > 0 {
|
||||
if match, err := gregex.MatchString(router.RegRule, url); err == nil && len(match) > 0 {
|
||||
if len(match) > len(router.RegNames) {
|
||||
tpl = router.Uri
|
||||
hasPageName := false
|
||||
for i, name := range router.RegNames {
|
||||
rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name)
|
||||
if !hasPageName && strings.Compare(name, page.PageName) == 0 {
|
||||
hasPageName = true
|
||||
tpl, _ = gregex.ReplaceString(rule, `{.page}`, tpl)
|
||||
} else {
|
||||
tpl, _ = gregex.ReplaceString(rule, match[i+1], tpl)
|
||||
}
|
||||
}
|
||||
if !hasPageName {
|
||||
tpl = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 获取链接地址
|
||||
func (page *Page) GetLink(url, text, title, style string) string {
|
||||
if len(style) > 0 {
|
||||
style = fmt.Sprintf(`class="%s" `, style)
|
||||
}
|
||||
if len(page.AjaxActionName) > 0 {
|
||||
return fmt.Sprintf(`<a %shref='#' onclick="%s('%s')">%s</a>`, style, page.AjaxActionName, url, text)
|
||||
// GetLink returns the HTML link tag <a> content for given page number.
|
||||
func (p *Page) GetLink(page int, text, title string) string {
|
||||
if len(p.AjaxActionName) > 0 {
|
||||
return fmt.Sprintf(
|
||||
`<a class="%s" href="javascript:%s('%s')" title="%s">%s</a>`,
|
||||
p.LinkStyle, p.AjaxActionName, p.GetUrl(page), title, text,
|
||||
)
|
||||
} else {
|
||||
return fmt.Sprintf(`<a %shref="%s" title="%s">%s</a>`, style, url, title, text)
|
||||
return fmt.Sprintf(
|
||||
`<a class="%s" href="%s" title="%s">%s</a>`,
|
||||
p.LinkStyle, p.GetUrl(page), title, text,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
116
util/gpage/gpage_unit_test.go
Normal file
116
util/gpage/gpage_unit_test.go
Normal file
@ -0,0 +1,116 @@
|
||||
// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// go test *.go -bench=".*"
|
||||
|
||||
package gpage_test
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/util/gpage"
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(9, 2, 1, `/user/list?page={.page}`)
|
||||
gtest.Assert(page.TotalSize, 9)
|
||||
gtest.Assert(page.TotalPage, 5)
|
||||
gtest.Assert(page.CurrentPage, 1)
|
||||
})
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(9, 2, 0, `/user/list?page={.page}`)
|
||||
gtest.Assert(page.TotalSize, 9)
|
||||
gtest.Assert(page.TotalPage, 5)
|
||||
gtest.Assert(page.CurrentPage, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Basic(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(9, 2, 1, `/user/list?page={.page}`)
|
||||
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="/user/list?page=2" title="">></a>`)
|
||||
gtest.Assert(page.PrevPage(), `<span class="GPageSpan"><</span>`)
|
||||
gtest.Assert(page.FirstPage(), `<span class="GPageSpan">|<</span>`)
|
||||
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="/user/list?page=5" title="">>|</a>`)
|
||||
gtest.Assert(page.PageBar(), `<span class="GPageSpan">1</span><a class="GPageLink" href="/user/list?page=2" title="2">2</a><a class="GPageLink" href="/user/list?page=3" title="3">3</a><a class="GPageLink" href="/user/list?page=4" title="4">4</a><a class="GPageLink" href="/user/list?page=5" title="5">5</a>`)
|
||||
})
|
||||
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(9, 2, 3, `/user/list?page={.page}`)
|
||||
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="/user/list?page=4" title="">></a>`)
|
||||
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="/user/list?page=2" title=""><</a>`)
|
||||
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="/user/list?page=1" title="">|<</a>`)
|
||||
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="/user/list?page=5" title="">>|</a>`)
|
||||
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="/user/list?page=1" title="1">1</a><a class="GPageLink" href="/user/list?page=2" title="2">2</a><span class="GPageSpan">3</span><a class="GPageLink" href="/user/list?page=4" title="4">4</a><a class="GPageLink" href="/user/list?page=5" title="5">5</a>`)
|
||||
})
|
||||
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(9, 2, 5, `/user/list?page={.page}`)
|
||||
gtest.Assert(page.NextPage(), `<span class="GPageSpan">></span>`)
|
||||
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="/user/list?page=4" title=""><</a>`)
|
||||
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="/user/list?page=1" title="">|<</a>`)
|
||||
gtest.Assert(page.LastPage(), `<span class="GPageSpan">>|</span>`)
|
||||
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="/user/list?page=1" title="1">1</a><a class="GPageLink" href="/user/list?page=2" title="2">2</a><a class="GPageLink" href="/user/list?page=3" title="3">3</a><a class="GPageLink" href="/user/list?page=4" title="4">4</a><span class="GPageSpan">5</span>`)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_CustomTag(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
|
||||
page.PrevPageTag = "《"
|
||||
page.NextPageTag = "》"
|
||||
page.FirstPageTag = "|《"
|
||||
page.LastPageTag = "》|"
|
||||
page.PrevBarTag = "《《"
|
||||
page.NextBarTag = "》》"
|
||||
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="/user/list/3" title="">》</a>`)
|
||||
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="/user/list/1" title="">《</a>`)
|
||||
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="/user/list/1" title="">|《</a>`)
|
||||
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="/user/list/5" title="">》|</a>`)
|
||||
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="/user/list/1" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="/user/list/3" title="3">3</a><a class="GPageLink" href="/user/list/4" title="4">4</a><a class="GPageLink" href="/user/list/5" title="5">5</a>`)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_CustomStyle(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
|
||||
page.LinkStyle = "MyPageLink"
|
||||
page.SpanStyle = "MyPageSpan"
|
||||
page.SelectStyle = "MyPageSelect"
|
||||
gtest.Assert(page.NextPage(), `<a class="MyPageLink" href="/user/list/3" title="">></a>`)
|
||||
gtest.Assert(page.PrevPage(), `<a class="MyPageLink" href="/user/list/1" title=""><</a>`)
|
||||
gtest.Assert(page.FirstPage(), `<a class="MyPageLink" href="/user/list/1" title="">|<</a>`)
|
||||
gtest.Assert(page.LastPage(), `<a class="MyPageLink" href="/user/list/5" title="">>|</a>`)
|
||||
gtest.Assert(page.PageBar(), `<a class="MyPageLink" href="/user/list/1" title="1">1</a><span class="MyPageSpan">2</span><a class="MyPageLink" href="/user/list/3" title="3">3</a><a class="MyPageLink" href="/user/list/4" title="4">4</a><a class="MyPageLink" href="/user/list/5" title="5">5</a>`)
|
||||
gtest.Assert(page.SelectBar(), `<select name="MyPageSelect" onchange="window.location.href=this.value"><option value="/user/list/1">1</option><option value="/user/list/2" selected>2</option><option value="/user/list/3">3</option><option value="/user/list/4">4</option><option value="/user/list/5">5</option></select>`)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Ajax(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
|
||||
page.AjaxActionName = "LoadPage"
|
||||
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">></a>`)
|
||||
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title=""><</a>`)
|
||||
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">|<</a>`)
|
||||
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">>|</a>`)
|
||||
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="3">3</a><a class="GPageLink" href="javascript:LoadPage('/user/list/4')" title="4">4</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="5">5</a>`)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_PredefinedContent(t *testing.T) {
|
||||
gtest.Case(t, func() {
|
||||
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
|
||||
page.AjaxActionName = "LoadPage"
|
||||
gtest.Assert(page.GetContent(1), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">上一页</a> <span class="current">2</span> <a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页</a>`)
|
||||
gtest.Assert(page.GetContent(2), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">首页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title=""><<上一页</a><span class="current">[第2页]</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页>></a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">尾页</a>第<select name="GPageSelect" onchange="window.location.href=this.value"><option value="/user/list/1">1</option><option value="/user/list/2" selected>2</option><option value="/user/list/3">3</option><option value="/user/list/4">4</option><option value="/user/list/5">5</option></select>页`)
|
||||
gtest.Assert(page.GetContent(3), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">首页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">上一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="3">3</a><a class="GPageLink" href="javascript:LoadPage('/user/list/4')" title="4">4</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="5">5</a><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">尾页</a><span>当前页2/5</span> <span>共5条</span>`)
|
||||
gtest.Assert(page.GetContent(4), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">首页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">上一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="3">3</a><a class="GPageLink" href="javascript:LoadPage('/user/list/4')" title="4">4</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="5">5</a><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">尾页</a>`)
|
||||
gtest.Assert(page.GetContent(5), ``)
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user