Merge pull request #2 from gogf/master

update
This commit is contained in:
wenzi 2020-03-09 23:35:59 +08:00 committed by GitHub
commit 24ea9f9245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
84 changed files with 2348 additions and 1083 deletions

View File

@ -0,0 +1,119 @@
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package driver
import (
"database/sql"
"fmt"
"github.com/gogf/gf/database/gdb"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gstr"
)
type MyDriver struct {
*gdb.Core
}
// Open creates and returns a underlying sql.DB object for mysql.
func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
} else {
source = fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true&parseTime=true&loc=Local",
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset,
)
}
intlog.Printf("Open: %s", source)
if db, err := sql.Open("mysql", source); err == nil {
return db, nil
} else {
return nil, err
}
}
// getChars returns the security char for this type of database.
func (d *MyDriver) GetChars() (charLeft string, charRight string) {
return "`", "`"
}
// handleSqlBeforeExec handles the sql before posts it to database.
func (d *MyDriver) HandleSqlBeforeExec(sql string) string {
return sql
}
// Tables retrieves and returns the tables of current schema.
func (d *MyDriver) Tables(schema ...string) (tables []string, err error) {
var result gdb.Result
link, err := d.DB.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
if err != nil {
return
}
for _, m := range result {
for _, v := range m {
tables = append(tables, v.String())
}
}
return
}
// gdb.TableFields retrieves and returns the fields information of specified table of current schema.
//
// Note that it returns a map containing the field name and its corresponding fields.
// As a map is unsorted, the gdb.TableField struct has a "Index" field marks its sequence in the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
func (d *MyDriver) TableFields(table string, schema ...string) (fields map[string]*gdb.TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function gdb.TableFields supports only single table operations")
}
checkSchema := d.DB.GetSchema()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := d.DB.GetCache().GetOrSetFunc(
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
func() interface{} {
var result gdb.Result
var link *sql.DB
link, err = d.DB.GetSlave(checkSchema)
if err != nil {
return nil
}
result, err = d.DB.DoGetAll(
link,
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
)
if err != nil {
return nil
}
fields = make(map[string]*gdb.TableField)
for i, m := range result {
fields[m["Field"].String()] = &gdb.TableField{
Index: i,
Name: m["Field"].String(),
Type: m["Type"].String(),
Null: m["Null"].Bool(),
Key: m["Key"].String(),
Default: m["Default"].Val(),
Extra: m["Extra"].String(),
Comment: m["Comment"].String(),
}
}
return fields
}, 0)
if err == nil {
fields = v.(map[string]*gdb.TableField)
}
return
}

View File

@ -0,0 +1 @@
package main

View File

@ -11,11 +11,11 @@ func main() {
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)
r, e := db.Table("test").OrderBy("id asc").All()
r, e := db.Table("test").Order("id asc").All()
if e != nil {
panic(e)
fmt.Println(e)
}
if r != nil {
fmt.Println(r.ToList())
fmt.Println(r.List())
}
}

View File

@ -9,5 +9,4 @@ func main() {
db.SetDebug(true)
db.Table("user").Fields("DISTINCT id,nickname").Filter().All()
}

View File

@ -1,37 +1,16 @@
package main
import (
"net/http"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"fmt"
"github.com/gogf/gf/text/gregex"
)
func MiddlewareAuth(r *ghttp.Request) {
token := r.Get("token")
if token == "123456" {
r.Response.Writeln("auth")
r.Middleware.Next()
} else {
r.Response.WriteStatus(http.StatusForbidden)
}
}
func MiddlewareCORS(r *ghttp.Request) {
r.Response.Writeln("cors")
r.Response.CORSDefault()
r.Middleware.Next()
}
func main() {
s := g.Server()
s.Use(MiddlewareCORS)
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
group.Middleware(MiddlewareAuth)
group.ALL("/user/list", func(r *ghttp.Request) {
r.Response.Writeln("list")
})
data := "@var(.prefix)您收到的验证码为:@var(.code),请在@var(.expire)内完成验证"
result, err := gregex.ReplaceStringFuncMatch(`(@var\(\.\w+\))`, data, func(match []string) string {
fmt.Println(match)
return "#"
})
s.SetPort(8199)
s.Run()
fmt.Println(err)
fmt.Println(result)
}

View File

@ -4,13 +4,12 @@ import (
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/os/gview"
"github.com/gogf/gf/util/gpage"
)
func main() {
s := ghttp.GetServer()
s.BindHandler("/page/demo", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String())
page := r.GetPage(100, 10)
buffer, _ := gview.ParseContent(`
<html>
<head>

View File

@ -4,14 +4,13 @@ import (
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/os/gview"
"github.com/gogf/gf/util/gpage"
)
func main() {
s := ghttp.GetServer()
s.BindHandler("/page/ajax", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
page.EnableAjax("DoAjax")
page := r.GetPage(100, 10)
page.AjaxActionName = "DoAjax"
buffer, _ := gview.ParseContent(`
<html>
<head>
@ -29,11 +28,17 @@ func main() {
</script>
</head>
<body>
<div>{{.page}}</div>
<div>{{.page1}}</div>
<div>{{.page2}}</div>
<div>{{.page3}}</div>
<div>{{.page4}}</div>
</body>
</html>
`, g.Map{
"page": page.GetContent(1),
"page1": page.GetContent(1),
"page2": page.GetContent(2),
"page3": page.GetContent(3),
"page4": page.GetContent(4),
})
r.Response.Write(buffer)
})

View File

@ -23,7 +23,7 @@ func wrapContent(page *gpage.Page) string {
func main() {
s := ghttp.GetServer()
s.BindHandler("/page/custom1/*page", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
page := r.GetPage(100, 10)
content := wrapContent(page)
buffer, _ := gview.ParseContent(`
<html>

View File

@ -15,7 +15,7 @@ func pageContent(page *gpage.Page) string {
page.LastPageTag = "LastPage"
pageStr := page.FirstPage()
pageStr += page.PrevPage()
pageStr += page.PageBar("current-page")
pageStr += page.PageBar()
pageStr += page.NextPage()
pageStr += page.LastPage()
return pageStr
@ -24,7 +24,7 @@ func pageContent(page *gpage.Page) string {
func main() {
s := ghttp.GetServer()
s.BindHandler("/page/custom2/*page", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
page := r.GetPage(100, 10)
buffer, _ := gview.ParseContent(`
<html>
<head>

View File

@ -4,13 +4,12 @@ import (
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/os/gview"
"github.com/gogf/gf/util/gpage"
)
func main() {
s := g.Server()
s.BindHandler("/page/static/*page", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
page := r.GetPage(100, 10)
buffer, _ := gview.ParseContent(`
<html>
<head>

View File

@ -4,13 +4,12 @@ import (
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/os/gview"
"github.com/gogf/gf/util/gpage"
)
func main() {
s := g.Server()
s.BindHandler("/:obj/*action/{page}.html", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router)
page := r.GetPage(100, 10)
buffer, _ := gview.ParseContent(`
<html>
<head>

View File

@ -4,14 +4,13 @@ import (
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/os/gview"
"github.com/gogf/gf/util/gpage"
)
func main() {
s := g.Server()
s.BindHandler("/page/template/{page}.html", func(r *ghttp.Request) {
page := gpage.New(100, 10, r.Get("page"), r.URL.String())
page.SetUrlTemplate("/order/list/{.page}.html")
page := r.GetPage(100, 10)
page.UrlTemplate = "/order/list/{.page}.html"
buffer, _ := gview.ParseContent(`
<html>
<head>

View File

@ -15,7 +15,7 @@ We currently accept donation by Alipay/WechatPay, please note your github/gitee
|[zhuhuan12](https://gitee.com/zhuhuan12)|gitee|¥50.00 |
|[zfan_codes](https://gitee.com/zfan_codes)|gitee|¥10.00 |
|[arden](https://github.com/arden)|alipay|¥10.00 |
|[macnie](https://www.macnie.com)|wechat|¥100.00 |
|[macnie](https://www.macnie.com)|wechat|¥110.00 |
|lah|wechat|¥100.00 |
|x*z|wechat|¥20.00 |
|潘兄|wechat|¥100.00 |

View File

@ -439,6 +439,9 @@ func (a *SortedArray) SetUnique(unique bool) *SortedArray {
func (a *SortedArray) Unique() *SortedArray {
a.mu.Lock()
defer a.mu.Unlock()
if len(a.array) == 0 {
return a
}
i := 0
for {
if i == len(a.array)-1 {

View File

@ -429,6 +429,10 @@ func (a *SortedIntArray) SetUnique(unique bool) *SortedIntArray {
// Unique uniques the array, clear repeated items.
func (a *SortedIntArray) Unique() *SortedIntArray {
a.mu.Lock()
defer a.mu.Unlock()
if len(a.array) == 0 {
return a
}
i := 0
for {
if i == len(a.array)-1 {
@ -440,7 +444,6 @@ func (a *SortedIntArray) Unique() *SortedIntArray {
i++
}
}
a.mu.Unlock()
return a
}

View File

@ -414,6 +414,10 @@ func (a *SortedStrArray) SetUnique(unique bool) *SortedStrArray {
// Unique uniques the array, clear repeated items.
func (a *SortedStrArray) Unique() *SortedStrArray {
a.mu.Lock()
defer a.mu.Unlock()
if len(a.array) == 0 {
return a
}
i := 0
for {
if i == len(a.array)-1 {
@ -425,7 +429,6 @@ func (a *SortedStrArray) Unique() *SortedStrArray {
i++
}
}
a.mu.Unlock()
return a
}

View File

@ -13,8 +13,8 @@ import (
)
func Example_basic() {
// 创建普通的数组,默认并发安全(带锁)
a := garray.New(true)
// 创建普通的数组
a := garray.New()
// 添加数据项
for i := 0; i < 10; i++ {

View File

@ -4,7 +4,7 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package gqueue provides a dynamic/static concurrent-safe queue.
// Package gqueue provides dynamic/static concurrent-safe queue.
//
// Features:
//
@ -25,6 +25,7 @@ import (
"github.com/gogf/gf/container/gtype"
)
// Queue is a concurrent-safe queue built on doubly linked list and channel.
type Queue struct {
limit int // Limit for queue size.
list *glist.List // Underlying list structure for data maintaining.
@ -54,14 +55,14 @@ func New(limit ...int) *Queue {
q.list = glist.New(true)
q.events = make(chan struct{}, math.MaxInt32)
q.C = make(chan interface{}, gDEFAULT_QUEUE_SIZE)
go q.startAsyncLoop()
go q.asyncLoopFromListToChannel()
}
return q
}
// startAsyncLoop starts an asynchronous goroutine,
// asyncLoopFromListToChannel starts an asynchronous goroutine,
// which handles the data synchronization from list <q.list> to channel <q.C>.
func (q *Queue) startAsyncLoop() {
func (q *Queue) asyncLoopFromListToChannel() {
defer func() {
if q.closed.Val() {
_ = recover()

View File

@ -22,7 +22,7 @@ import (
"github.com/gogf/gf/util/grand"
)
// DB is the interface for ORM operations.
// DB defines the interfaces for ORM operations.
type DB interface {
// Open creates a raw connection object for database with given node configuration.
// Note that it is not recommended using the this function manually.
@ -34,14 +34,14 @@ type DB interface {
Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error)
// Internal APIs for CURD, which can be overwrote for custom CURD implements.
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error)
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
doPrepare(link dbLink, query string) (*sql.Stmt, error)
doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error)
doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error)
doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error)
DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error)
DoGetAll(link Link, query string, args ...interface{}) (result Result, err error)
DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error)
DoPrepare(link Link, query string) (*sql.Stmt, error)
DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error)
DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error)
DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error)
// Query APIs for convenience purpose.
GetAll(query string, args ...interface{}) (Result, error)
@ -52,11 +52,11 @@ type DB interface {
GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error
GetScan(objPointer interface{}, query string, args ...interface{}) error
// Master/Slave support.
// Master/Slave specification support.
Master() (*sql.DB, error)
Slave() (*sql.DB, error)
// Ping.
// Ping-Pong.
PingMaster() error
PingSlave() error
@ -75,48 +75,49 @@ type DB interface {
Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error)
Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error)
// Create model.
// Model creation.
From(tables string) *Model
Table(tables string) *Model
Schema(schema string) *Schema
// Configuration methods.
GetCache() *gcache.Cache
SetDebug(debug bool)
GetDebug() bool
SetSchema(schema string)
GetSchema() string
GetPrefix() string
SetLogger(logger *glog.Logger)
GetLogger() *glog.Logger
SetMaxIdleConnCount(n int)
SetMaxOpenConnCount(n int)
SetMaxConnLifetime(d time.Duration)
// Utility methods.
GetChars() (charLeft string, charRight string)
GetMaster(schema ...string) (*sql.DB, error)
GetSlave(schema ...string) (*sql.DB, error)
QuoteWord(s string) string
QuoteString(s string) string
QuotePrefixTableName(table string) string
Tables(schema ...string) (tables []string, err error)
TableFields(table string, schema ...string) (map[string]*TableField, error)
// HandleSqlBeforeCommit is a hook function, which deals with the sql string before
// it's committed to underlying driver. The parameter <link> specifies the current
// database connection operation object. You can modify the sql string <query> and its
// arguments <args> as you wish before they're committed to driver.
HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{})
// Internal methods.
getCache() *gcache.Cache
getChars() (charLeft string, charRight string)
getDebug() bool
getPrefix() string
getMaster(schema ...string) (*sql.DB, error)
getSlave(schema ...string) (*sql.DB, error)
quoteWord(s string) string
quoteString(s string) string
handleTableName(table string) string
filterFields(schema, table string, data map[string]interface{}) map[string]interface{}
convertValue(fieldValue []byte, fieldType string) interface{}
rowsToResult(rows *sql.Rows) (Result, error)
handleSqlBeforeExec(sql string) string
}
// dbLink is a common database function wrapper interface for internal usage.
type dbLink interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
}
// dbBase is the base struct for database management.
type dbBase struct {
db DB // DB interface object.
// Core is the base struct for database management.
type Core struct {
DB DB // DB interface object.
group string // Configuration group name.
debug *gtype.Bool // Enable debug mode for the database.
cache *gcache.Cache // Cache manager.
@ -128,6 +129,12 @@ type dbBase struct {
maxConnLifetime time.Duration // Max TTL for a connection.
}
// Driver is the interface for integrating sql drivers into package gdb.
type Driver interface {
// New creates and returns a database object for specified database server.
New(core *Core, node *ConfigNode) (DB, error)
}
// Sql is the sql recording struct.
type Sql struct {
Sql string // SQL string(may contain reserved char '?').
@ -150,6 +157,13 @@ type TableField struct {
Comment string // Comment.
}
// Link is a common database function wrapper interface.
type Link interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
}
// Value is the field value type.
type Value = *gvar.Var
@ -176,10 +190,24 @@ const (
)
var (
// Instance map.
// instances is the management map for instances.
instances = gmap.NewStrAnyMap(true)
// driverMap manages all custom registered driver.
driverMap = map[string]Driver{
"mysql": &DriverMysql{},
"mssql": &DriverMssql{},
"pgsql": &DriverPgsql{},
"oracle": &DriverOracle{},
"sqlite": &DriverSqlite{},
}
)
// Register registers custom database driver to gdb.
func Register(name string, driver Driver) error {
driverMap[name] = driver
return nil
}
// New creates and returns an ORM object with global configurations.
// The parameter <name> specifies the configuration group name,
// which is DEFAULT_GROUP_NAME in default.
@ -196,31 +224,24 @@ func New(name ...string) (db DB, err error) {
}
if _, ok := configs.config[group]; ok {
if node, err := getConfigNodeByGroup(group, true); err == nil {
base := &dbBase{
group: group,
debug: gtype.NewBool(),
cache: gcache.New(),
schema: gtype.NewString(),
logger: glog.New(),
prefix: node.Prefix,
// Default max connection life time if user does not configure.
maxConnLifetime: gDEFAULT_CONN_MAX_LIFE_TIME,
c := &Core{
group: group,
debug: gtype.NewBool(),
cache: gcache.New(),
schema: gtype.NewString(),
logger: glog.New(),
prefix: node.Prefix,
maxConnLifetime: gDEFAULT_CONN_MAX_LIFE_TIME, // Default max connection life time if user does not configure.
}
switch node.Type {
case "mysql":
base.db = &dbMysql{dbBase: base}
case "pgsql":
base.db = &dbPgsql{dbBase: base}
case "mssql":
base.db = &dbMssql{dbBase: base}
case "sqlite":
base.db = &dbSqlite{dbBase: base}
case "oracle":
base.db = &dbOracle{dbBase: base}
default:
if v, ok := driverMap[node.Type]; ok {
c.DB, err = v.New(c, node)
if err != nil {
return nil, err
}
return c.DB, nil
} else {
return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type))
}
return base.db, nil
} else {
return nil, err
}
@ -321,9 +342,9 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
// getSqlDb retrieves and returns a underlying database connection object.
// The parameter <master> specifies whether retrieves master node connection if
// master-slave nodes are configured.
func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) {
// Load balance.
node, err := getConfigNodeByGroup(bs.group, master)
node, err := getConfigNodeByGroup(c.group, master)
if err != nil {
return nil, err
}
@ -332,7 +353,7 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er
node.Charset = "utf8"
}
// Changes the schema.
nodeSchema := bs.schema.Val()
nodeSchema := c.schema.Val()
if len(schema) > 0 && schema[0] != "" {
nodeSchema = schema[0]
}
@ -343,25 +364,25 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er
node = &n
}
// Cache the underlying connection object by node.
v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} {
sqlDb, err = bs.db.Open(node)
v := c.cache.GetOrSetFuncLock(node.String(), func() interface{} {
sqlDb, err = c.DB.Open(node)
if err != nil {
return nil
}
if bs.maxIdleConnCount > 0 {
sqlDb.SetMaxIdleConns(bs.maxIdleConnCount)
if c.maxIdleConnCount > 0 {
sqlDb.SetMaxIdleConns(c.maxIdleConnCount)
} else if node.MaxIdleConnCount > 0 {
sqlDb.SetMaxIdleConns(node.MaxIdleConnCount)
}
if bs.maxOpenConnCount > 0 {
sqlDb.SetMaxOpenConns(bs.maxOpenConnCount)
if c.maxOpenConnCount > 0 {
sqlDb.SetMaxOpenConns(c.maxOpenConnCount)
} else if node.MaxOpenConnCount > 0 {
sqlDb.SetMaxOpenConns(node.MaxOpenConnCount)
}
if bs.maxConnLifetime > 0 {
sqlDb.SetConnMaxLifetime(bs.maxConnLifetime * time.Second)
if c.maxConnLifetime > 0 {
sqlDb.SetConnMaxLifetime(c.maxConnLifetime * time.Second)
} else if node.MaxConnLifetime > 0 {
sqlDb.SetConnMaxLifetime(node.MaxConnLifetime * time.Second)
}
@ -371,40 +392,7 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er
sqlDb = v.(*sql.DB)
}
if node.Debug {
bs.db.SetDebug(node.Debug)
c.DB.SetDebug(node.Debug)
}
return
}
// SetSchema changes the schema for this database connection object.
// Importantly note that when schema configuration changed for the database,
// it affects all operations on the database object in the future.
func (bs *dbBase) SetSchema(schema string) {
bs.schema.Set(schema)
}
// Master creates and returns a connection from master node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (bs *dbBase) Master() (*sql.DB, error) {
return bs.getSqlDb(true, bs.schema.Val())
}
// Slave creates and returns a connection from slave node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (bs *dbBase) Slave() (*sql.DB, error) {
return bs.getSqlDb(false, bs.schema.Val())
}
// getMaster acts like function Master but with additional <schema> parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Master.
func (bs *dbBase) getMaster(schema ...string) (*sql.DB, error) {
return bs.getSqlDb(true, schema...)
}
// getSlave acts like function Slave but with additional <schema> parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Slave.
func (bs *dbBase) getSlave(schema ...string) (*sql.DB, error) {
return bs.getSqlDb(false, schema...)
}

View File

@ -16,7 +16,6 @@ import (
"strings"
"github.com/gogf/gf/container/gvar"
"github.com/gogf/gf/os/gcache"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/text/gregex"
"github.com/gogf/gf/util/gconv"
@ -32,22 +31,34 @@ var (
lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`)
)
// Master creates and returns a connection from master node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Master() (*sql.DB, error) {
return c.getSqlDb(true, c.schema.Val())
}
// Slave creates and returns a connection from slave node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Slave() (*sql.DB, error) {
return c.getSqlDb(false, c.schema.Val())
}
// Query commits one query SQL to underlying driver and returns the execution result.
// It is most commonly used for data querying.
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := bs.db.Slave()
func (c *Core) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := c.DB.Slave()
if err != nil {
return nil, err
}
return bs.db.doQuery(link, query, args...)
return c.DB.DoQuery(link, query, args...)
}
// doQuery commits the query string and its arguments to underlying driver
// through given link object and returns the execution result.
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
func (c *Core) DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error) {
query, args = formatQuery(query, args)
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
if c.DB.GetDebug() {
mTime1 := gtime.TimestampMilli()
rows, err = link.Query(query, args...)
mTime2 := gtime.TimestampMilli()
@ -59,7 +70,7 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows
Start: mTime1,
End: mTime2,
}
bs.printSql(s)
c.writeSqlToLogger(s)
} else {
rows, err = link.Query(query, args...)
}
@ -73,20 +84,20 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows
// Exec commits one query SQL to underlying driver and returns the execution result.
// It is most commonly used for data inserting and updating.
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
func (c *Core) Exec(query string, args ...interface{}) (result sql.Result, err error) {
link, err := c.DB.Master()
if err != nil {
return nil, err
}
return bs.db.doExec(link, query, args...)
return c.DB.DoExec(link, query, args...)
}
// doExec commits the query string and its arguments to underlying driver
// through given link object and returns the execution result.
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
func (c *Core) DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error) {
query, args = formatQuery(query, args)
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
if c.DB.GetDebug() {
mTime1 := gtime.TimestampMilli()
result, err = link.Exec(query, args...)
mTime2 := gtime.TimestampMilli()
@ -98,7 +109,7 @@ func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result
Start: mTime1,
End: mTime2,
}
bs.printSql(s)
c.writeSqlToLogger(s)
} else {
result, err = link.Exec(query, args...)
}
@ -113,50 +124,50 @@ func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result
//
// The parameter <execOnMaster> specifies whether executing the sql on master node,
// or else it executes the sql on slave node if master-slave configured.
func (bs *dbBase) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
err := (error)(nil)
link := (dbLink)(nil)
link := (Link)(nil)
if len(execOnMaster) > 0 && execOnMaster[0] {
if link, err = bs.db.Master(); err != nil {
if link, err = c.DB.Master(); err != nil {
return nil, err
}
} else {
if link, err = bs.db.Slave(); err != nil {
if link, err = c.DB.Slave(); err != nil {
return nil, err
}
}
return bs.db.doPrepare(link, query)
return c.DB.DoPrepare(link, query)
}
// doPrepare calls prepare function on given link object and returns the statement object.
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
func (c *Core) DoPrepare(link Link, query string) (*sql.Stmt, error) {
return link.Prepare(query)
}
// GetAll queries and returns data records from database.
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
return bs.db.doGetAll(nil, query, args...)
func (c *Core) GetAll(query string, args ...interface{}) (Result, error) {
return c.DB.DoGetAll(nil, query, args...)
}
// doGetAll queries and returns data records from database.
func (bs *dbBase) doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) {
func (c *Core) DoGetAll(link Link, query string, args ...interface{}) (result Result, err error) {
if link == nil {
link, err = bs.db.Slave()
link, err = c.DB.Slave()
if err != nil {
return nil, err
}
}
rows, err := bs.doQuery(link, query, args...)
rows, err := c.DB.DoQuery(link, query, args...)
if err != nil || rows == nil {
return nil, err
}
defer rows.Close()
return bs.db.rowsToResult(rows)
return c.DB.rowsToResult(rows)
}
// GetOne queries and returns one record from database.
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
list, err := bs.GetAll(query, args...)
func (c *Core) GetOne(query string, args ...interface{}) (Record, error) {
list, err := c.DB.GetAll(query, args...)
if err != nil {
return nil, err
}
@ -168,8 +179,8 @@ func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
// GetStruct queries one record from database and converts it to given struct.
// The parameter <pointer> should be a pointer to struct.
func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface{}) error {
one, err := bs.GetOne(query, args...)
func (c *Core) GetStruct(pointer interface{}, query string, args ...interface{}) error {
one, err := c.DB.GetOne(query, args...)
if err != nil {
return err
}
@ -181,8 +192,8 @@ func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface
// GetStructs queries records from database and converts them to given struct.
// The parameter <pointer> should be type of struct slice: []struct/[]*struct.
func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interface{}) error {
all, err := bs.GetAll(query, args...)
func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}) error {
all, err := c.DB.GetAll(query, args...)
if err != nil {
return err
}
@ -198,7 +209,7 @@ func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interfac
// If parameter <pointer> is type of struct pointer, it calls GetStruct internally for
// the conversion. If parameter <pointer> is type of slice, it calls GetStructs internally
// for conversion.
func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}) error {
func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) error {
t := reflect.TypeOf(pointer)
k := t.Kind()
if k != reflect.Ptr {
@ -207,9 +218,9 @@ func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}
k = t.Elem().Kind()
switch k {
case reflect.Array, reflect.Slice:
return bs.db.GetStructs(pointer, query, args...)
return c.DB.GetStructs(pointer, query, args...)
case reflect.Struct:
return bs.db.GetStruct(pointer, query, args...)
return c.DB.GetStruct(pointer, query, args...)
}
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
}
@ -217,8 +228,8 @@ func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}
// GetValue queries and returns the field value from database.
// The sql should queries only one field from database, or else it returns only one
// field of the result.
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
one, err := bs.GetOne(query, args...)
func (c *Core) GetValue(query string, args ...interface{}) (Value, error) {
one, err := c.DB.GetOne(query, args...)
if err != nil {
return nil, err
}
@ -229,13 +240,13 @@ func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
}
// GetCount queries and returns the count from database.
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
func (c *Core) GetCount(query string, args ...interface{}) (int, error) {
// If the query fields do not contains function "COUNT",
// it replaces the query string and adds the "COUNT" function to the fields.
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
}
value, err := bs.GetValue(query, args...)
value, err := c.DB.GetValue(query, args...)
if err != nil {
return 0, err
}
@ -243,8 +254,8 @@ func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
}
// PingMaster pings the master node to check authentication or keeps the connection alive.
func (bs *dbBase) PingMaster() error {
if master, err := bs.db.Master(); err != nil {
func (c *Core) PingMaster() error {
if master, err := c.DB.Master(); err != nil {
return err
} else {
return master.Ping()
@ -252,8 +263,8 @@ func (bs *dbBase) PingMaster() error {
}
// PingSlave pings the slave node to check authentication or keeps the connection alive.
func (bs *dbBase) PingSlave() error {
if slave, err := bs.db.Slave(); err != nil {
func (c *Core) PingSlave() error {
if slave, err := c.DB.Slave(); err != nil {
return err
} else {
return slave.Ping()
@ -264,13 +275,13 @@ func (bs *dbBase) PingSlave() error {
// You should call Commit or Rollback functions of the transaction object
// if you no longer use the transaction. Commit or Rollback functions will also
// close the transaction automatically.
func (bs *dbBase) Begin() (*TX, error) {
if master, err := bs.db.Master(); err != nil {
func (c *Core) Begin() (*TX, error) {
if master, err := c.DB.Master(); err != nil {
return nil, err
} else {
if tx, err := master.Begin(); err == nil {
return &TX{
db: bs.db,
db: c.DB,
tx: tx,
master: master,
}, nil
@ -289,8 +300,8 @@ func (bs *dbBase) Begin() (*TX, error) {
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter <batch> specifies the batch operation count when given data is slice.
func (bs *dbBase) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_DEFAULT, batch...)
func (c *Core) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_DEFAULT, batch...)
}
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
@ -302,8 +313,8 @@ func (bs *dbBase) Insert(table string, data interface{}, batch ...int) (sql.Resu
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter <batch> specifies the batch operation count when given data is slice.
func (bs *dbBase) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_IGNORE, batch...)
func (c *Core) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_IGNORE, batch...)
}
// Replace does "REPLACE INTO ..." statement for the table.
@ -318,8 +329,8 @@ func (bs *dbBase) InsertIgnore(table string, data interface{}, batch ...int) (sq
// The parameter <data> can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// If given data is type of slice, it then does batch replacing, and the optional parameter
// <batch> specifies the batch operation count.
func (bs *dbBase) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_REPLACE, batch...)
func (c *Core) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_REPLACE, batch...)
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
@ -333,54 +344,60 @@ func (bs *dbBase) Replace(table string, data interface{}, batch ...int) (sql.Res
//
// If given data is type of slice, it then does batch saving, and the optional parameter
// <batch> specifies the batch operation count.
func (bs *dbBase) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, gINSERT_OPTION_SAVE, batch...)
func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_SAVE, batch...)
}
// doInsert inserts or updates data for given table.
//
// The parameter <data> can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
// Data(g.Map{"uid": 10000, "name":"john"})
// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"})
//
// The parameter <option> values are as follows:
// 0: insert: just insert, if there's unique/primary key in the data, it returns error;
// 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
// 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one;
// 3: ignore: if there's unique/primary key in the data, it ignores the inserting;
func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
func (c *Core) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
var fields []string
var values []string
var params []interface{}
var dataMap Map
table = bs.db.handleTableName(table)
rv := reflect.ValueOf(data)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
table = c.DB.QuotePrefixTableName(table)
reflectValue := reflect.ValueOf(data)
reflectKind := reflectValue.Kind()
if reflectKind == reflect.Ptr {
reflectValue = reflectValue.Elem()
reflectKind = reflectValue.Kind()
}
switch kind {
switch reflectKind {
case reflect.Slice, reflect.Array:
return bs.db.doBatchInsert(link, table, data, option, batch...)
return c.DB.DoBatchInsert(link, table, data, option, batch...)
case reflect.Map, reflect.Struct:
dataMap = varToMapDeep(data)
dataMap = DataToMapDeep(data)
default:
return result, errors.New(fmt.Sprint("unsupported data type:", kind))
return result, errors.New(fmt.Sprint("unsupported data type:", reflectKind))
}
if len(dataMap) == 0 {
return nil, errors.New("data cannot be empty")
}
charL, charR := bs.db.getChars()
charL, charR := c.DB.GetChars()
for k, v := range dataMap {
fields = append(fields, charL+k+charR)
values = append(values, "?")
params = append(params, v)
}
operation := getInsertOperationByOption(option)
operation := GetInsertOperationByOption(option)
updateStr := ""
if option == gINSERT_OPTION_SAVE {
for k, _ := range dataMap {
if len(updateStr) > 0 {
updateStr += ","
}
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
updateStr += fmt.Sprintf(
"%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
)
@ -388,45 +405,50 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", updateStr)
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
if link, err = c.DB.Master(); err != nil {
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updateStr),
params...)
return c.DB.DoExec(
link,
fmt.Sprintf(
"%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updateStr,
),
params...,
)
}
// BatchInsert batch inserts data.
// The parameter <list> must be type of slice of map or struct.
func (bs *dbBase) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_DEFAULT, batch...)
func (c *Core) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_DEFAULT, batch...)
}
// BatchInsert batch inserts data with ignore option.
// The parameter <list> must be type of slice of map or struct.
func (bs *dbBase) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_IGNORE, batch...)
func (c *Core) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_IGNORE, batch...)
}
// BatchReplace batch replaces data.
// The parameter <list> must be type of slice of map or struct.
func (bs *dbBase) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_REPLACE, batch...)
func (c *Core) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_REPLACE, batch...)
}
// BatchSave batch replaces data.
// The parameter <list> must be type of slice of map or struct.
func (bs *dbBase) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...)
func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...)
}
// doBatchInsert batch inserts/replaces/saves data.
func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
var keys, values []string
var params []interface{}
table = bs.db.handleTableName(table)
table = c.DB.QuotePrefixTableName(table)
listMap := (List)(nil)
switch v := list.(type) {
case Result:
@ -449,10 +471,10 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
case reflect.Slice, reflect.Array:
listMap = make(List, rv.Len())
for i := 0; i < rv.Len(); i++ {
listMap[i] = varToMapDeep(rv.Index(i).Interface())
listMap[i] = DataToMapDeep(rv.Index(i).Interface())
}
case reflect.Map, reflect.Struct:
listMap = List{varToMapDeep(list)}
listMap = List{DataToMapDeep(list)}
default:
return result, errors.New(fmt.Sprint("unsupported list type:", kind))
}
@ -461,7 +483,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
return result, errors.New("data list cannot be empty")
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
if link, err = c.DB.Master(); err != nil {
return
}
}
@ -471,20 +493,21 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
keys = append(keys, k)
holders = append(holders, "?")
}
// Prepare the result pointer.
// Prepare the batch result pointer.
batchResult := new(batchSqlResult)
charL, charR := bs.db.getChars()
charL, charR := c.DB.GetChars()
keysStr := charL + strings.Join(keys, charR+","+charL) + charR
valueHolderStr := "(" + strings.Join(holders, ",") + ")"
operation := getInsertOperationByOption(option)
operation := GetInsertOperationByOption(option)
updateStr := ""
if option == gINSERT_OPTION_SAVE {
for _, k := range keys {
if len(updateStr) > 0 {
updateStr += ","
}
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
updateStr += fmt.Sprintf(
"%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
)
@ -504,7 +527,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
}
values = append(values, valueHolderStr)
if len(values) == batchNum || (i == listMapLen-1 && len(values) > 0) {
r, err := bs.db.doExec(
r, err := c.DB.DoExec(
link,
fmt.Sprintf(
"%s INTO %s(%s) VALUES%s %s",
@ -546,18 +569,18 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
newWhere, newArgs := formatWhere(bs.db, condition, args, false)
func (c *Core) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
newWhere, newArgs := formatWhere(c.DB, condition, args, false)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return bs.db.doUpdate(nil, table, data, newWhere, newArgs...)
return c.DB.DoUpdate(nil, table, data, newWhere, newArgs...)
}
// doUpdate does "UPDATE ... " statement for the table.
// Also see Update.
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
table = bs.db.handleTableName(table)
func (c *Core) DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
table = c.DB.QuotePrefixTableName(table)
updates := ""
rv := reflect.ValueOf(data)
kind := rv.Kind()
@ -569,8 +592,8 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
switch kind {
case reflect.Map, reflect.Struct:
var fields []string
for k, v := range varToMapDeep(data) {
fields = append(fields, bs.db.quoteWord(k)+"=?")
for k, v := range DataToMapDeep(data) {
fields = append(fields, c.DB.QuoteWord(k)+"=?")
params = append(params, v)
}
updates = strings.Join(fields, ",")
@ -585,11 +608,15 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
}
// If no link passed, it then uses the master link.
if link == nil {
if link, err = bs.db.Master(); err != nil {
if link, err = c.DB.Master(); err != nil {
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...)
return c.DB.DoExec(
link,
fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition),
args...,
)
}
// Delete does "DELETE FROM ... " statement for the table.
@ -603,38 +630,28 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
// "status IN (?)", g.Slice{1,2,3}
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
newWhere, newArgs := formatWhere(bs.db, condition, args, false)
func (c *Core) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
newWhere, newArgs := formatWhere(c.DB, condition, args, false)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return bs.db.doDelete(nil, table, newWhere, newArgs...)
return c.DB.DoDelete(nil, table, newWhere, newArgs...)
}
// doDelete does "DELETE FROM ... " statement for the table.
// Also see Delete.
func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) {
func (c *Core) DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) {
if link == nil {
if link, err = bs.db.Master(); err != nil {
if link, err = c.DB.Master(); err != nil {
return nil, err
}
}
table = bs.db.handleTableName(table)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
// getCache returns the internal cache object.
func (bs *dbBase) getCache() *gcache.Cache {
return bs.cache
}
// getPrefix returns the table prefix string configured.
func (bs *dbBase) getPrefix() string {
return bs.prefix
table = c.DB.QuotePrefixTableName(table)
return c.DB.DoExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
// rowsToResult converts underlying data record type sql.Rows to Result type.
func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
func (c *Core) rowsToResult(rows *sql.Rows) (Result, error) {
if !rows.Next() {
return nil, nil
}
@ -671,7 +688,7 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
// it should do a copy of it.
v := make([]byte, len(value))
copy(v, value)
row[columnNames[i]] = gvar.New(bs.db.convertValue(v, columnTypes[i]))
row[columnNames[i]] = gvar.New(c.DB.convertValue(v, columnTypes[i]))
}
}
records = append(records, row)
@ -682,39 +699,23 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
return records, nil
}
// handleTableName adds prefix string and quote chars for the table. It handles table string like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut".
// MarshalJSON implements the interface MarshalJSON for json.Marshal.
// It just returns the pointer address.
//
// Note that, this will automatically checks the table prefix whether already added, if true it does
// nothing to the table name, or else adds the prefix to the table name.
func (bs *dbBase) handleTableName(table string) string {
charLeft, charRight := bs.db.getChars()
prefix := bs.db.getPrefix()
return doHandleTableName(table, prefix, charLeft, charRight)
// Note that this interface implements mainly for workaround for a json infinite loop bug
// of Golang version < v1.14.
func (c *Core) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`%+v`, c)), nil
}
// quoteWord checks given string <s> a word, if true quotes it with security chars of the database
// and returns the quoted string; or else return <s> without any change.
func (bs *dbBase) quoteWord(s string) string {
charLeft, charRight := bs.db.getChars()
return doQuoteWord(s, charLeft, charRight)
}
// quoteString quotes string with quote chars. Strings like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
func (bs *dbBase) quoteString(s string) string {
charLeft, charRight := bs.db.getChars()
return doQuoteString(s, charLeft, charRight)
}
// printSql outputs the sql object to logger.
// writeSqlToLogger outputs the sql object to logger.
// It is enabled when configuration "debug" is true.
func (bs *dbBase) printSql(v *Sql) {
func (c *Core) writeSqlToLogger(v *Sql) {
s := fmt.Sprintf("[%d ms] %s", v.End-v.Start, v.Format)
if v.Error != nil {
s += "\nError: " + v.Error.Error()
bs.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s)
c.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s)
} else {
bs.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s)
c.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s)
}
}

View File

@ -8,6 +8,7 @@ package gdb
import (
"fmt"
"github.com/gogf/gf/os/gcache"
"sync"
"time"
@ -114,29 +115,29 @@ func GetDefaultGroup() string {
}
// SetLogger sets the logger for orm.
func (bs *dbBase) SetLogger(logger *glog.Logger) {
bs.logger = logger
func (c *Core) SetLogger(logger *glog.Logger) {
c.logger = logger
}
// GetLogger returns the logger of the orm.
func (bs *dbBase) GetLogger() *glog.Logger {
return bs.logger
func (c *Core) GetLogger() *glog.Logger {
return c.logger
}
// SetMaxIdleConnCount sets the max idle connection count for underlying connection pool.
func (bs *dbBase) SetMaxIdleConnCount(n int) {
bs.maxIdleConnCount = n
func (c *Core) SetMaxIdleConnCount(n int) {
c.maxIdleConnCount = n
}
// SetMaxOpenConnCount sets the max open connection count for underlying connection pool.
func (bs *dbBase) SetMaxOpenConnCount(n int) {
bs.maxOpenConnCount = n
func (c *Core) SetMaxOpenConnCount(n int) {
c.maxOpenConnCount = n
}
// SetMaxConnLifetime sets the connection TTL for underlying connection pool.
// If parameter <d> <= 0, it means the connection never expires.
func (bs *dbBase) SetMaxConnLifetime(d time.Duration) {
bs.maxConnLifetime = d
func (c *Core) SetMaxConnLifetime(d time.Duration) {
c.maxConnLifetime = d
}
// String returns the node as string.
@ -155,14 +156,36 @@ func (node *ConfigNode) String() string {
}
// SetDebug enables/disables the debug mode.
func (bs *dbBase) SetDebug(debug bool) {
if bs.debug.Val() == debug {
func (c *Core) SetDebug(debug bool) {
if c.debug.Val() == debug {
return
}
bs.debug.Set(debug)
c.debug.Set(debug)
}
// getDebug returns the debug value.
func (bs *dbBase) getDebug() bool {
return bs.debug.Val()
// GetDebug returns the debug value.
func (c *Core) GetDebug() bool {
return c.debug.Val()
}
// GetCache returns the internal cache object.
func (c *Core) GetCache() *gcache.Cache {
return c.cache
}
// GetPrefix returns the table prefix string configured.
func (c *Core) GetPrefix() string {
return c.prefix
}
// SetSchema changes the schema for this database connection object.
// Importantly note that when schema configuration changed for the database,
// it affects all operations on the database object in the future.
func (c *Core) SetSchema(schema string) {
c.schema.Set(schema)
}
// GetSchema returns the schema configured.
func (c *Core) GetSchema() string {
return c.schema.Val()
}

View File

@ -0,0 +1,86 @@
// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
//
package gdb
import (
"database/sql"
)
// GetMaster acts like function Master but with additional <schema> parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Master.
func (c *Core) GetMaster(schema ...string) (*sql.DB, error) {
return c.getSqlDb(true, schema...)
}
// GetSlave acts like function Slave but with additional <schema> parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Slave.
func (c *Core) GetSlave(schema ...string) (*sql.DB, error) {
return c.getSqlDb(false, schema...)
}
// QuoteWord checks given string <s> a word, if true quotes it with security chars of the database
// and returns the quoted string; or else return <s> without any change.
func (c *Core) QuoteWord(s string) string {
charLeft, charRight := c.DB.GetChars()
return doQuoteWord(s, charLeft, charRight)
}
// QuoteString quotes string with quote chars. Strings like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
func (c *Core) QuoteString(s string) string {
charLeft, charRight := c.DB.GetChars()
return doQuoteString(s, charLeft, charRight)
}
// QuotePrefixTableName adds prefix string and quotes chars for the table.
// It handles table string like:
// "user", "user u",
// "user,user_detail",
// "user u, user_detail ut",
// "user as u, user_detail as ut".
//
// Note that, this will automatically checks the table prefix whether already added,
// if true it does nothing to the table name, or else adds the prefix to the table name.
func (c *Core) QuotePrefixTableName(table string) string {
charLeft, charRight := c.DB.GetChars()
return doHandleTableName(table, c.DB.GetPrefix(), charLeft, charRight)
}
// GetChars returns the security char for current database.
// It does nothing in default.
func (c *Core) GetChars() (charLeft string, charRight string) {
return "", ""
}
// HandleSqlBeforeCommit handles the sql before posts it to database.
// It does nothing in default.
func (c *Core) HandleSqlBeforeCommit(sql string) string {
return sql
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
//
// It does nothing in default.
func (c *Core) Tables(schema ...string) (tables []string, err error) {
return
}
// TableFields retrieves and returns the fields information of specified table of current schema.
//
// Note that it returns a map containing the field name and its corresponding fields.
// As a map is unsorted, the TableField struct has a "Index" field marks its sequence in the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
//
// It does nothing in default.
func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
return
}

View File

@ -5,7 +5,7 @@
// You can obtain one at https://github.com/gogf/gf.
//
// Note:
// 1. It needs manually import: _ "github.com/lib/pq"
// 1. It needs manually import: _ "github.com/denisenkom/go-mssqldb"
// 2. It does not support Save/Replace features.
// 3. It does not support LastInsertId.
@ -22,12 +22,21 @@ import (
"github.com/gogf/gf/text/gregex"
)
type dbMssql struct {
*dbBase
// DriverMssql is the driver for SQL server database.
type DriverMssql struct {
*Core
}
// New creates and returns a database object for SQL server.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) {
return &DriverMssql{
Core: core,
}, nil
}
// Open creates and returns a underlying sql.DB object for mssql.
func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) {
source := ""
if config.LinkInfo != "" {
source = config.LinkInfo
@ -45,13 +54,13 @@ func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
}
}
// getChars returns the security char for this type of database.
func (db *dbMssql) getChars() (charLeft string, charRight string) {
// GetChars returns the security char for this type of database.
func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
return "\"", "\""
}
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
func (db *dbMssql) handleSqlBeforeExec(query string) string {
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string "@px".
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
@ -59,10 +68,10 @@ func (db *dbMssql) handleSqlBeforeExec(query string) string {
return fmt.Sprintf("@p%d", index)
})
str, _ = gregex.ReplaceString("\"", "", str)
return db.parseSql(str)
return d.parseSql(str), args
}
func (db *dbMssql) parseSql(sql string) string {
func (d *DriverMssql) parseSql(sql string) string {
// SELECT * FROM USER WHERE ID=1 LIMIT 1
if m, _ := gregex.MatchString(`^SELECT(.+)LIMIT 1$`, sql); len(m) > 1 {
return fmt.Sprintf(`SELECT TOP 1 %s`, m[1])
@ -163,44 +172,45 @@ func (db *dbMssql) parseSql(sql string) string {
}
// Tables retrieves and returns the tables of current schema.
func (db *dbMssql) Tables(schema ...string) (tables []string, err error) {
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) {
var result Result
link, err := db.getSlave(schema...)
link, err := d.DB.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = db.doGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
result, err = d.DB.DoGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
if err != nil {
return
}
for _, m := range result {
for _, v := range m {
tables = append(tables, strings.ToLower(v.String()))
tables = append(tables, v.String())
}
}
return
}
// TableFields retrieves and returns the fields information of specified table of current schema.
func (db *dbMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function TableFields supports only single table operations")
}
checkSchema := db.schema.Val()
checkSchema := d.DB.GetSchema()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := db.cache.GetOrSetFunc(
v := d.DB.GetCache().GetOrSetFunc(
fmt.Sprintf(`mssql_table_fields_%s_%s`, table, checkSchema), func() interface{} {
var result Result
var link *sql.DB
link, err = db.getSlave(checkSchema)
link, err = d.DB.GetSlave(checkSchema)
if err != nil {
return nil
}
result, err = db.doGetAll(link, fmt.Sprintf(`
result, err = d.DB.DoGetAll(link, fmt.Sprintf(`
SELECT c.name as FIELD, CASE t.name
WHEN 'numeric' THEN t.name + '(' + convert(varchar(20),c.xprec) + ',' + convert(varchar(20),c.xscale) + ')'
WHEN 'char' THEN t.name + '(' + convert(varchar(20),c.length)+ ')'

View File

@ -12,15 +12,24 @@ import (
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gstr"
_ "github.com/gf-third/mysql"
_ "github.com/go-sql-driver/mysql"
)
type dbMysql struct {
*dbBase
// DriverMysql is the driver for mysql database.
type DriverMysql struct {
*Core
}
// New creates and returns a database object for mysql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) {
return &DriverMysql{
Core: core,
}, nil
}
// Open creates and returns a underlying sql.DB object for mysql.
func (db *dbMysql) Open(config *ConfigNode) (*sql.DB, error) {
func (d *DriverMysql) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
@ -31,31 +40,32 @@ func (db *dbMysql) Open(config *ConfigNode) (*sql.DB, error) {
)
}
intlog.Printf("Open: %s", source)
if db, err := sql.Open("gf-mysql", source); err == nil {
if db, err := sql.Open("mysql", source); err == nil {
return db, nil
} else {
return nil, err
}
}
// getChars returns the security char for this type of database.
func (db *dbMysql) getChars() (charLeft string, charRight string) {
// GetChars returns the security char for this type of database.
func (d *DriverMysql) GetChars() (charLeft string, charRight string) {
return "`", "`"
}
// handleSqlBeforeExec handles the sql before posts it to database.
func (db *dbMysql) handleSqlBeforeExec(sql string) string {
return sql
// HandleSqlBeforeCommit handles the sql before posts it to database.
func (d *DriverMysql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
return sql, args
}
// Tables retrieves and returns the tables of current schema.
func (bs *dbBase) Tables(schema ...string) (tables []string, err error) {
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) {
var result Result
link, err := bs.db.getSlave(schema...)
link, err := d.DB.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = bs.db.doGetAll(link, `SHOW TABLES`)
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
if err != nil {
return
}
@ -73,27 +83,27 @@ func (bs *dbBase) Tables(schema ...string) (tables []string, err error) {
// As a map is unsorted, the TableField struct has a "Index" field marks its sequence in the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
func (bs *dbBase) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverMysql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function TableFields supports only single table operations")
}
checkSchema := bs.schema.Val()
checkSchema := d.schema.Val()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := bs.cache.GetOrSetFunc(
v := d.cache.GetOrSetFunc(
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
func() interface{} {
var result Result
var link *sql.DB
link, err = bs.db.getSlave(checkSchema)
link, err = d.DB.GetSlave(checkSchema)
if err != nil {
return nil
}
result, err = bs.doGetAll(
result, err = d.DB.DoGetAll(
link,
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, bs.db.quoteWord(table)),
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
)
if err != nil {
return nil

View File

@ -24,8 +24,9 @@ import (
"github.com/gogf/gf/text/gregex"
)
type dbOracle struct {
*dbBase
// DriverOracle is the driver for oracle database.
type DriverOracle struct {
*Core
}
const (
@ -33,8 +34,16 @@ const (
tableAlias2 = "GFORM2"
)
// New creates and returns a database object for oracle.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) {
return &DriverOracle{
Core: core,
}, nil
}
// Open creates and returns a underlying sql.DB object for oracle.
func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
func (d *DriverOracle) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
@ -49,13 +58,13 @@ func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
}
}
// getChars returns the security char for this type of database.
func (db *dbOracle) getChars() (charLeft string, charRight string) {
// GetChars returns the security char for this type of database.
func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
return "\"", "\""
}
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
func (db *dbOracle) handleSqlBeforeExec(query string) string {
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string ":x".
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
@ -63,10 +72,10 @@ func (db *dbOracle) handleSqlBeforeExec(query string) string {
return fmt.Sprintf(":%d", index)
})
str, _ = gregex.ReplaceString("\"", "", str)
return db.parseSql(str)
return d.parseSql(str), args
}
func (db *dbOracle) parseSql(sql string) string {
func (d *DriverOracle) parseSql(sql string) string {
patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
if gregex.IsMatchString(patten, sql) == false {
return sql
@ -123,36 +132,37 @@ func (db *dbOracle) parseSql(sql string) string {
}
// Tables retrieves and returns the tables of current schema.
func (db *dbOracle) Tables(schema ...string) (tables []string, err error) {
// It's mainly used in cli tool chain for automatically generating the models.
// Note that it ignores the parameter <schema> in oracle database, as it is not necessary.
func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) {
var result Result
result, err = db.doGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME")
result, err = d.DB.DoGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME")
if err != nil {
return
}
for _, m := range result {
for _, v := range m {
tables = append(tables, strings.ToLower(v.String()))
tables = append(tables, v.String())
}
}
return
}
// TableFields retrieves and returns the fields information of specified table of current schema.
func (db *dbOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function TableFields supports only single table operations")
}
checkSchema := db.schema.Val()
checkSchema := d.DB.GetSchema()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := db.cache.GetOrSetFunc(
v := d.DB.GetCache().GetOrSetFunc(
fmt.Sprintf(`oracle_table_fields_%s_%s`, table, checkSchema),
func() interface{} {
result := (Result)(nil)
result, err = db.GetAll(fmt.Sprintf(`
result, err = d.DB.GetAll(fmt.Sprintf(`
SELECT COLUMN_NAME AS FIELD, CASE DATA_TYPE
WHEN 'NUMBER' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')'
WHEN 'FLOAT' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')'
@ -177,11 +187,11 @@ func (db *dbOracle) TableFields(table string, schema ...string) (fields map[stri
return
}
func (db *dbOracle) getTableUniqueIndex(table string) (fields map[string]map[string]string, err error) {
func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[string]string, err error) {
table = strings.ToUpper(table)
v := db.cache.GetOrSetFunc("table_unique_index_"+table, func() interface{} {
v := d.DB.GetCache().GetOrSetFunc("table_unique_index_"+table, func() interface{} {
res := (Result)(nil)
res, err = db.GetAll(fmt.Sprintf(`
res, err = d.DB.GetAll(fmt.Sprintf(`
SELECT INDEX_NAME,COLUMN_NAME,CHAR_LENGTH FROM USER_IND_COLUMNS
WHERE TABLE_NAME = '%s'
AND INDEX_NAME IN(SELECT INDEX_NAME FROM USER_INDEXES WHERE TABLE_NAME='%s' AND UNIQUENESS='UNIQUE')
@ -203,7 +213,7 @@ func (db *dbOracle) getTableUniqueIndex(table string) (fields map[string]map[str
return
}
func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
var fields []string
var values []string
var params []interface{}
@ -218,11 +228,11 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
case reflect.Slice:
fallthrough
case reflect.Array:
return db.db.doBatchInsert(link, table, data, option, batch...)
return d.DB.DoBatchInsert(link, table, data, option, batch...)
case reflect.Map:
fallthrough
case reflect.Struct:
dataMap = varToMapDeep(data)
dataMap = DataToMapDeep(data)
default:
return result, errors.New(fmt.Sprint("unsupported data type:", kind))
}
@ -231,7 +241,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
indexMap := make(map[string]string)
indexExists := false
if option != gINSERT_OPTION_DEFAULT {
index, err := db.getTableUniqueIndex(table)
index, err := d.getTableUniqueIndex(table)
if err != nil {
return nil, err
}
@ -253,7 +263,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
onStr := make([]string, 0)
updateStr := make([]string, 0)
charL, charR := db.db.getChars()
charL, charR := d.DB.GetChars()
for k, v := range dataMap {
k = strings.ToUpper(k)
@ -279,7 +289,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
}
if link == nil {
if link, err = db.db.Master(); err != nil {
if link, err = d.DB.Master(); err != nil {
return nil, err
}
}
@ -294,9 +304,9 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
table, tableAlias1, strings.Join(subSqlStr, ","), tableAlias2,
strings.Join(onStr, "AND"), strings.Join(updateStr, ","), strings.Join(fields, ","), strings.Join(values, ","),
)
return db.db.doExec(link, tmp, params...)
return d.DB.DoExec(link, tmp, params...)
case gINSERT_OPTION_IGNORE:
return db.db.doExec(link,
return d.DB.DoExec(link,
fmt.Sprintf(
"INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)",
table, strings.Join(indexs, ","), table, strings.Join(fields, ","), strings.Join(values, ","),
@ -305,7 +315,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
}
}
return db.db.doExec(
return d.DB.DoExec(
link,
fmt.Sprintf(
"INSERT INTO %s(%s) VALUES(%s)",
@ -314,7 +324,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option
params...)
}
func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
var keys []string
var values []string
var params []interface{}
@ -342,12 +352,12 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
case reflect.Array:
listMap = make(List, rv.Len())
for i := 0; i < rv.Len(); i++ {
listMap[i] = varToMapDeep(rv.Index(i).Interface())
listMap[i] = DataToMapDeep(rv.Index(i).Interface())
}
case reflect.Map:
fallthrough
case reflect.Struct:
listMap = List{Map(varToMapDeep(list))}
listMap = List{Map(DataToMapDeep(list))}
default:
return result, errors.New(fmt.Sprint("unsupported list type:", kind))
}
@ -357,7 +367,7 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
return result, errors.New("empty data list")
}
if link == nil {
if link, err = db.db.Master(); err != nil {
if link, err = d.DB.Master(); err != nil {
return
}
}
@ -368,14 +378,14 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
holders = append(holders, "?")
}
batchResult := new(batchSqlResult)
charL, charR := db.db.getChars()
charL, charR := d.DB.GetChars()
keyStr := charL + strings.Join(keys, charL+","+charR) + charR
valueHolderStr := strings.Join(holders, ",")
// 当操作类型非insert时调用单笔的insert功能
if option != gINSERT_OPTION_DEFAULT {
for _, v := range listMap {
r, err := db.doInsert(link, table, v, option, 1)
r, err := d.DB.DoInsert(link, table, v, option, 1)
if err != nil {
return r, err
}
@ -402,10 +412,9 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
params = append(params, listMap[i][k])
}
values = append(values, valueHolderStr)
intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr))
if len(intoStr) == batchNum {
r, err := db.db.doExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
r, err := d.DB.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
if err != nil {
return r, err
}
@ -421,7 +430,7 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o
}
// 处理最后不构成指定批量的数据
if len(intoStr) > 0 {
r, err := db.db.doExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
r, err := d.DB.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
if err != nil {
return r, err
}

View File

@ -21,12 +21,21 @@ import (
"github.com/gogf/gf/text/gregex"
)
type dbPgsql struct {
*dbBase
// DriverPgsql is the driver for postgresql database.
type DriverPgsql struct {
*Core
}
// New creates and returns a database object for postgresql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) {
return &DriverPgsql{
Core: core,
}, nil
}
// Open creates and returns a underlying sql.DB object for pgsql.
func (db *dbPgsql) Open(config *ConfigNode) (*sql.DB, error) {
func (d *DriverPgsql) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
@ -44,13 +53,13 @@ func (db *dbPgsql) Open(config *ConfigNode) (*sql.DB, error) {
}
}
// getChars returns the security char for this type of database.
func (db *dbPgsql) getChars() (charLeft string, charRight string) {
// GetChars returns the security char for this type of database.
func (d *DriverPgsql) GetChars() (charLeft string, charRight string) {
return "\"", "\""
}
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
func (db *dbPgsql) handleSqlBeforeExec(sql string) string {
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverPgsql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string "$x".
sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
@ -58,13 +67,14 @@ func (db *dbPgsql) handleSqlBeforeExec(sql string) string {
return fmt.Sprintf("$%d", index)
})
sql, _ = gregex.ReplaceString(` LIMIT (\d+),\s*(\d+)`, ` LIMIT $1 OFFSET $2`, sql)
return sql
return sql, args
}
// Tables retrieves and returns the tables of current schema.
func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) {
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) {
var result Result
link, err := db.getSlave(schema...)
link, err := d.DB.GetSlave(schema...)
if err != nil {
return nil, err
}
@ -73,7 +83,7 @@ func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) {
if len(schema) > 0 && schema[0] != "" {
query = fmt.Sprintf("SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = '%s' ORDER BY TABLENAME", schema[0])
}
result, err = db.doGetAll(link, query)
result, err = d.DB.DoGetAll(link, query)
if err != nil {
return
}
@ -86,25 +96,25 @@ func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) {
}
// TableFields retrieves and returns the fields information of specified table of current schema.
func (db *dbPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function TableFields supports only single table operations")
}
table, _ = gregex.ReplaceString("\"", "", table)
checkSchema := db.schema.Val()
checkSchema := d.DB.GetSchema()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := db.cache.GetOrSetFunc(
v := d.DB.GetCache().GetOrSetFunc(
fmt.Sprintf(`pgsql_table_fields_%s_%s`, table, checkSchema), func() interface{} {
var result Result
var link *sql.DB
link, err = db.getSlave(checkSchema)
link, err = d.DB.GetSlave(checkSchema)
if err != nil {
return nil
}
result, err = db.doGetAll(link, fmt.Sprintf(`
result, err = d.DB.DoGetAll(link, fmt.Sprintf(`
SELECT a.attname AS field, t.typname AS type FROM pg_class c, pg_attribute a
LEFT OUTER JOIN pg_description b ON a.attrelid=b.objoid AND a.attnum = b.objsubid,pg_type t
WHERE c.relname = '%s' and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid

View File

@ -16,12 +16,21 @@ import (
"github.com/gogf/gf/text/gstr"
)
type dbSqlite struct {
*dbBase
// DriverSqlite is the driver for sqlite database.
type DriverSqlite struct {
*Core
}
// New creates and returns a database object for sqlite.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *DriverSqlite) New(core *Core, node *ConfigNode) (DB, error) {
return &DriverSqlite{
Core: core,
}, nil
}
// Open creates and returns a underlying sql.DB object for sqlite.
func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
func (d *DriverSqlite) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
@ -36,30 +45,31 @@ func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
}
}
// getChars returns the security char for this type of database.
func (db *dbSqlite) getChars() (charLeft string, charRight string) {
// GetChars returns the security char for this type of database.
func (d *DriverSqlite) GetChars() (charLeft string, charRight string) {
return "`", "`"
}
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
// @todo 需要增加对Save方法的支持可使用正则来实现替换
// @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
func (d *DriverSqlite) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
return sql, args
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
// TODO
func (db *dbSqlite) Tables(schema ...string) (tables []string, err error) {
func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) {
return
}
// TableFields retrieves and returns the fields information of specified table of current schema.
// TODO
func (db *dbSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function TableFields supports only single table operations")
}
return
}
// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver.
// @todo 需要增加对Save方法的支持可使用正则来实现替换
// @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
func (db *dbSqlite) handleSqlBeforeExec(sql string) string {
return sql
}

View File

@ -51,7 +51,54 @@ var (
quoteWordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
)
// handleTableName adds prefix string and quote chars for the table. It handles table string like:
// GetInsertOperationByOption returns proper insert option with given parameter <option>.
func GetInsertOperationByOption(option int) string {
var operator string
switch option {
case gINSERT_OPTION_REPLACE:
operator = "REPLACE"
case gINSERT_OPTION_IGNORE:
operator = "INSERT IGNORE"
default:
operator = "INSERT"
}
return operator
}
// DataToMapDeep converts struct object to map type recursively.
func DataToMapDeep(obj interface{}) map[string]interface{} {
data := gconv.Map(obj, ORM_TAG_FOR_STRUCT)
for key, value := range data {
rv := reflect.ValueOf(value)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Struct:
// The underlying driver supports time.Time/*time.Time types.
if _, ok := value.(time.Time); ok {
continue
}
if _, ok := value.(*time.Time); ok {
continue
}
// Use string conversion in default.
if s, ok := value.(apiString); ok {
data[key] = s.String()
continue
}
delete(data, key)
for k, v := range DataToMapDeep(value) {
data[k] = v
}
}
}
return data
}
// QuotePrefixTableName adds prefix string and quote chars for the table. It handles table string like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut", "user.user u".
//
// Note that, this will automatically checks the table prefix whether already added, if true it does
@ -196,7 +243,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
newArgs = formatWhereInterfaces(db, gconv.Interfaces(where), buffer, newArgs)
case reflect.Map:
for key, value := range varToMapDeep(where) {
for key, value := range DataToMapDeep(where) {
if omitEmpty && empty.IsEmpty(value) {
continue
}
@ -218,7 +265,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
})
break
}
for key, value := range varToMapDeep(where) {
for key, value := range DataToMapDeep(where) {
if omitEmpty && empty.IsEmpty(value) {
continue
}
@ -269,7 +316,7 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new
// formatWhereKeyValue handles each key-value pair of the parameter map.
func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} {
key = db.quoteWord(key)
key = db.QuoteWord(key)
if buffer.Len() > 0 {
buffer.WriteString(" AND ")
}
@ -314,39 +361,6 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
return newArgs
}
// varToMapDeep converts struct object to map type recursively.
func varToMapDeep(obj interface{}) map[string]interface{} {
data := gconv.Map(obj, ORM_TAG_FOR_STRUCT)
for key, value := range data {
rv := reflect.ValueOf(value)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Struct:
// The underlying driver supports time.Time/*time.Time types.
if _, ok := value.(time.Time); ok {
continue
}
if _, ok := value.(*time.Time); ok {
continue
}
// Use string conversion in default.
if s, ok := value.(apiString); ok {
data[key] = s.String()
continue
}
delete(data, key)
for k, v := range varToMapDeep(value) {
data[k] = v
}
}
}
return data
}
// handleArguments is a nice function which handles the query and its arguments before committing to
// underlying driver.
func handleArguments(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
@ -422,20 +436,6 @@ func formatError(err error, query string, args ...interface{}) error {
return err
}
// getInsertOperationByOption returns proper insert option with given parameter <option>.
func getInsertOperationByOption(option int) string {
var operator string
switch option {
case gINSERT_OPTION_REPLACE:
operator = "REPLACE"
case gINSERT_OPTION_IGNORE:
operator = "INSERT IGNORE"
default:
operator = "INSERT"
}
return operator
}
// bindArgsToQuery binds the arguments to the query string and returns a complete
// sql string, just for debugging.
func bindArgsToQuery(query string, args []interface{}) string {

View File

@ -68,10 +68,10 @@ const (
// Table creates and returns a new ORM model from given schema.
// The parameter <tables> can be more than one table names, like :
// "user", "user u", "user, user_detail", "user u, user_detail ud"
func (bs *dbBase) Table(table string) *Model {
table = bs.db.handleTableName(table)
func (c *Core) Table(table string) *Model {
table = c.DB.QuotePrefixTableName(table)
return &Model{
db: bs.db,
db: c.DB,
tablesInit: table,
tables: table,
fields: "*",
@ -82,23 +82,23 @@ func (bs *dbBase) Table(table string) *Model {
}
}
// Model is alias of dbBase.Table.
// See dbBase.Table.
func (bs *dbBase) Model(table string) *Model {
return bs.db.Table(table)
// Model is alias of Core.Table.
// See Core.Table.
func (c *Core) Model(table string) *Model {
return c.DB.Table(table)
}
// From is alias of dbBase.Table.
// See dbBase.Table.
// From is alias of Core.Table.
// See Core.Table.
// Deprecated.
func (bs *dbBase) From(table string) *Model {
return bs.db.Table(table)
func (c *Core) From(table string) *Model {
return c.DB.Table(table)
}
// Table acts like dbBase.Table except it operates on transaction.
// See dbBase.Table.
// Table acts like Core.Table except it operates on transaction.
// See Core.Table.
func (tx *TX) Table(table string) *Model {
table = tx.db.handleTableName(table)
table = tx.db.QuotePrefixTableName(table)
return &Model{
db: tx.db,
tx: tx,
@ -217,21 +217,21 @@ func (m *Model) getModel() *Model {
// LeftJoin does "LEFT JOIN ... ON ..." statement on the model.
func (m *Model) LeftJoin(table string, on string) *Model {
model := m.getModel()
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.handleTableName(table), on)
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on)
return model
}
// RightJoin does "RIGHT JOIN ... ON ..." statement on the model.
func (m *Model) RightJoin(table string, on string) *Model {
model := m.getModel()
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.handleTableName(table), on)
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on)
return model
}
// InnerJoin does "INNER JOIN ... ON ..." statement on the model.
func (m *Model) InnerJoin(table string, on string) *Model {
model := m.getModel()
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.handleTableName(table), on)
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on)
return model
}
@ -403,7 +403,7 @@ func (m *Model) Or(where interface{}, args ...interface{}) *Model {
// Group sets the "GROUP BY" statement for the model.
func (m *Model) Group(groupBy string) *Model {
model := m.getModel()
model.groupBy = m.db.quoteString(groupBy)
model.groupBy = m.db.QuoteString(groupBy)
return model
}
@ -417,7 +417,7 @@ func (m *Model) GroupBy(groupBy string) *Model {
// Order sets the "ORDER BY" statement for the model.
func (m *Model) Order(orderBy string) *Model {
model := m.getModel()
model.orderBy = m.db.quoteString(orderBy)
model.orderBy = m.db.QuoteString(orderBy)
return model
}
@ -540,11 +540,11 @@ func (m *Model) Data(data ...interface{}) *Model {
case reflect.Slice, reflect.Array:
list := make(List, rv.Len())
for i := 0; i < rv.Len(); i++ {
list[i] = varToMapDeep(rv.Index(i).Interface())
list[i] = DataToMapDeep(rv.Index(i).Interface())
}
model.data = list
case reflect.Map, reflect.Struct:
model.data = varToMapDeep(data[0])
model.data = DataToMapDeep(data[0])
default:
model.data = data[0]
}
@ -586,7 +586,7 @@ func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql.
if m.batch > 0 {
batch = m.batch
}
return m.db.doBatchInsert(
return m.db.DoBatchInsert(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(list),
@ -595,7 +595,7 @@ func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql.
)
} else if data, ok := m.data.(Map); ok {
// Single insert.
return m.db.doInsert(
return m.db.DoInsert(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(data),
@ -626,7 +626,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
if m.batch > 0 {
batch = m.batch
}
return m.db.doBatchInsert(
return m.db.DoBatchInsert(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(list),
@ -635,7 +635,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
)
} else if data, ok := m.data.(Map); ok {
// Single insert.
return m.db.doInsert(
return m.db.DoInsert(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(data),
@ -669,7 +669,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
if m.batch > 0 {
batch = m.batch
}
return m.db.doBatchInsert(
return m.db.DoBatchInsert(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(list),
@ -678,7 +678,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
)
} else if data, ok := m.data.(Map); ok {
// Single save.
return m.db.doInsert(
return m.db.DoInsert(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(data),
@ -712,7 +712,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
return nil, errors.New("updating table with empty data")
}
condition, conditionArgs := m.formatCondition(false)
return m.db.doUpdate(
return m.db.DoUpdate(
m.getLink(true),
m.tables,
m.filterDataForInsertOrUpdate(m.data),
@ -734,7 +734,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
}
}()
condition, conditionArgs := m.formatCondition(false)
return m.db.doDelete(m.getLink(true), m.tables, condition, conditionArgs...)
return m.db.DoDelete(m.getLink(true), m.tables, condition, conditionArgs...)
}
// Select is alias of Model.All.
@ -1045,7 +1045,7 @@ func (m *Model) doFilterDataMapForInsertOrUpdate(data Map, allowOmitEmpty bool)
// getLink returns the underlying database link object with configured <linkType> attribute.
// The parameter <master> specifies whether using the master node if master-slave configured.
func (m *Model) getLink(master bool) dbLink {
func (m *Model) getLink(master bool) Link {
if m.tx != nil {
return m.tx.tx
}
@ -1059,10 +1059,16 @@ func (m *Model) getLink(master bool) dbLink {
}
switch linkType {
case gLINK_TYPE_MASTER:
link, _ := m.db.getMaster(m.schema)
link, err := m.db.GetMaster(m.schema)
if err != nil {
panic(err)
}
return link
case gLINK_TYPE_SLAVE:
link, _ := m.db.getSlave(m.schema)
link, err := m.db.GetSlave(m.schema)
if err != nil {
panic(err)
}
return link
}
return nil
@ -1077,17 +1083,17 @@ func (m *Model) getAll(query string, args ...interface{}) (result Result, err er
if len(cacheKey) == 0 {
cacheKey = query + "/" + gconv.String(args)
}
if v := m.db.getCache().Get(cacheKey); v != nil {
if v := m.db.GetCache().Get(cacheKey); v != nil {
return v.(Result), nil
}
}
result, err = m.db.doGetAll(m.getLink(false), query, args...)
result, err = m.db.DoGetAll(m.getLink(false), query, args...)
// Cache the result.
if len(cacheKey) > 0 && err == nil {
if m.cacheDuration < 0 {
m.db.getCache().Remove(cacheKey)
m.db.GetCache().Remove(cacheKey)
} else {
m.db.getCache().Set(cacheKey, result, m.cacheDuration)
m.db.GetCache().Set(cacheKey, result, m.cacheDuration)
}
}
return result, err
@ -1113,7 +1119,7 @@ func (m *Model) getPrimaryKey() string {
// checkAndRemoveCache checks and remove the cache if necessary.
func (m *Model) checkAndRemoveCache() {
if m.cacheEnabled && m.cacheDuration < 0 && len(m.cacheName) > 0 {
m.db.getCache().Remove(m.cacheName)
m.db.GetCache().Remove(m.cacheName)
}
}

View File

@ -14,9 +14,9 @@ type Schema struct {
}
// Schema creates and returns a schema.
func (bs *dbBase) Schema(schema string) *Schema {
func (c *Core) Schema(schema string) *Schema {
return &Schema{
db: bs.db,
db: c.DB,
schema: schema,
}
}
@ -44,8 +44,8 @@ func (s *Schema) Table(table string) *Model {
return m
}
// Model is alias of dbBase.Table.
// See dbBase.Table.
// Model is alias of Core.Table.
// See Core.Table.
func (s *Schema) Model(table string) *Model {
return s.Table(table)
}

View File

@ -21,7 +21,7 @@ import (
// convertValue automatically checks and converts field value from database type
// to golang variable type.
func (bs *dbBase) convertValue(fieldValue []byte, fieldType string) interface{} {
func (c *Core) convertValue(fieldValue []byte, fieldType string) interface{} {
t, _ := gregex.ReplaceString(`\(.+\)`, "", fieldType)
t = strings.ToLower(t)
switch t {
@ -106,10 +106,10 @@ func (bs *dbBase) convertValue(fieldValue []byte, fieldType string) interface{}
}
// filterFields removes all key-value pairs which are not the field of given table.
func (bs *dbBase) filterFields(schema, table string, data map[string]interface{}) map[string]interface{} {
func (c *Core) filterFields(schema, table string, data map[string]interface{}) map[string]interface{} {
// It must use data copy here to avoid its changing the origin data map.
newDataMap := make(map[string]interface{}, len(data))
if fields, err := bs.db.TableFields(table, schema); err == nil {
if fields, err := c.DB.TableFields(table, schema); err == nil {
for k, v := range data {
if _, ok := fields[k]; ok {
newDataMap[k] = v

View File

@ -32,15 +32,15 @@ func (tx *TX) Rollback() error {
}
// Query does query operation on transaction.
// See dbBase.Query.
// See Core.Query.
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
return tx.db.doQuery(tx.tx, query, args...)
return tx.db.DoQuery(tx.tx, query, args...)
}
// Exec does none query operation on transaction.
// See dbBase.Exec.
// See Core.Exec.
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.db.doExec(tx.tx, query, args...)
return tx.db.DoExec(tx.tx, query, args...)
}
// Prepare creates a prepared statement for later queries or executions.
@ -49,7 +49,7 @@ func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
// The caller must call the statement's Close method
// when the statement is no longer needed.
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
return tx.db.doPrepare(tx.tx, query)
return tx.db.DoPrepare(tx.tx, query)
}
// GetAll queries and returns data records from database.
@ -154,7 +154,7 @@ func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
//
// The parameter <batch> specifies the batch operation count when given data is slice.
func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_DEFAULT, batch...)
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_DEFAULT, batch...)
}
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
@ -167,7 +167,7 @@ func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result,
//
// The parameter <batch> specifies the batch operation count when given data is slice.
func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_IGNORE, batch...)
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_IGNORE, batch...)
}
// Replace does "REPLACE INTO ..." statement for the table.
@ -183,7 +183,7 @@ func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Re
// If given data is type of slice, it then does batch replacing, and the optional parameter
// <batch> specifies the batch operation count.
func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_REPLACE, batch...)
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_REPLACE, batch...)
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
@ -198,31 +198,31 @@ func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result,
// If given data is type of slice, it then does batch saving, and the optional parameter
// <batch> specifies the batch operation count.
func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_SAVE, batch...)
return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_SAVE, batch...)
}
// BatchInsert batch inserts data.
// The parameter <list> must be type of slice of map or struct.
func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_DEFAULT, batch...)
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_DEFAULT, batch...)
}
// BatchInsert batch inserts data with ignore option.
// The parameter <list> must be type of slice of map or struct.
func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_IGNORE, batch...)
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_IGNORE, batch...)
}
// BatchReplace batch replaces data.
// The parameter <list> must be type of slice of map or struct.
func (tx *TX) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_REPLACE, batch...)
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_REPLACE, batch...)
}
// BatchSave batch replaces data.
// The parameter <list> must be type of slice of map or struct.
func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_SAVE, batch...)
return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_SAVE, batch...)
}
// Update does "UPDATE ... " statement for the table.
@ -244,7 +244,7 @@ func (tx *TX) Update(table string, data interface{}, condition interface{}, args
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return tx.db.doUpdate(tx.tx, table, data, newWhere, newArgs...)
return tx.db.DoUpdate(tx.tx, table, data, newWhere, newArgs...)
}
// Delete does "DELETE FROM ... " statement for the table.
@ -263,5 +263,5 @@ func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return tx.db.doDelete(tx.tx, table, newWhere, newArgs...)
return tx.db.DoDelete(tx.tx, table, newWhere, newArgs...)
}

View File

@ -0,0 +1,74 @@
// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gdb_test
import (
"github.com/gogf/gf/container/gtype"
"github.com/gogf/gf/database/gdb"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/test/gtest"
"testing"
)
// MyDriver is a custom database driver, which is used for testing only.
// For simplifying the unit testing case purpose, MyDriver struct inherits the mysql driver
// gdb.DriverMysql and overwrites its function HandleSqlBeforeCommit.
// So if there's any sql execution, it goes through MyDriver.HandleSqlBeforeCommit firstly and
// then gdb.DriverMysql.HandleSqlBeforeCommit.
// You can call it sql "HOOK" or "HiJack" as your will.
type MyDriver struct {
*gdb.DriverMysql
}
var (
customDriverName = "MyDriver"
latestSqlString = gtype.NewString() // For simplifying unit testing only.
)
// New creates and returns a database object for mysql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
return &MyDriver{
&gdb.DriverMysql{
Core: core,
},
}, nil
}
// HandleSqlBeforeCommit handles the sql before posts it to database.
// It here overwrites the same method of gdb.DriverMysql and makes some custom changes.
func (d *MyDriver) HandleSqlBeforeCommit(link gdb.Link, sql string, args []interface{}) (string, []interface{}) {
latestSqlString.Set(sql)
return d.DriverMysql.HandleSqlBeforeCommit(link, sql, args)
}
func init() {
// It here registers my custom driver in package initialization function "init".
// You can later use this type in the database configuration.
gdb.Register(customDriverName, &MyDriver{})
}
func Test_Custom_Driver(t *testing.T) {
gdb.AddConfigNode("driver-test", gdb.ConfigNode{
Host: "127.0.0.1",
Port: "3306",
User: "root",
Pass: "12345678",
Name: "test",
Type: customDriverName,
Role: "master",
Charset: "utf8",
})
gtest.Case(t, func() {
gtest.Assert(latestSqlString.Val(), "")
sqlString := "select 10000"
value, err := g.DB("driver-test").GetValue(sqlString)
gtest.Assert(err, nil)
gtest.Assert(value, 10000)
gtest.Assert(latestSqlString.Val(), sqlString)
})
}

View File

@ -681,6 +681,28 @@ func Test_Model_Struct(t *testing.T) {
})
}
func Test_Model_Struct_CustomType(t *testing.T) {
table := createInitTable()
defer dropTable(table)
type MyInt int
gtest.Case(t, func() {
type User struct {
Id MyInt
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
err := db.Table(table).Where("id=1").Struct(user)
gtest.Assert(err, nil)
gtest.Assert(user.NickName, "name_1")
gtest.Assert(user.CreateTime.String(), "2018-10-24 10:00:00")
})
}
func Test_Model_Structs(t *testing.T) {
table := createInitTable()
defer dropTable(table)

View File

@ -54,9 +54,9 @@ type PoolStats struct {
}
const (
gDEFAULT_POOL_IDLE_TIMEOUT = 60 * time.Second
gDEFAULT_POOL_IDLE_TIMEOUT = 30 * time.Second
gDEFAULT_POOL_CONN_TIMEOUT = 10 * time.Second
gDEFAULT_POOL_MAX_LIFE_TIME = 60 * time.Second
gDEFAULT_POOL_MAX_LIFE_TIME = 30 * time.Second
)
var (
@ -80,6 +80,7 @@ func New(config Config) *Redis {
config: config,
pool: pools.GetOrSetFuncLock(fmt.Sprintf("%v", config), func() interface{} {
return &redis.Pool{
Wait: true,
IdleTimeout: config.IdleTimeout,
MaxActive: config.MaxActive,
MaxIdle: config.MaxIdle,

View File

@ -210,6 +210,11 @@ func Test_Error(t *testing.T) {
func Test_Bool(t *testing.T) {
gtest.Case(t, func() {
redis := gredis.New(config)
defer func() {
redis.Do("DEL", "key-true")
redis.Do("DEL", "key-false")
}()
_, err := redis.Do("SET", "key-true", true)
gtest.Assert(err, nil)
@ -230,6 +235,8 @@ func Test_Int(t *testing.T) {
gtest.Case(t, func() {
redis := gredis.New(config)
key := guuid.New()
defer redis.Do("DEL", key)
_, err := redis.Do("SET", key, 1)
gtest.Assert(err, nil)
@ -243,6 +250,8 @@ func Test_HSet(t *testing.T) {
gtest.Case(t, func() {
redis := gredis.New(config)
key := guuid.New()
defer redis.Do("DEL", key)
_, err := redis.Do("HSET", key, "name", "john")
gtest.Assert(err, nil)
@ -251,3 +260,24 @@ func Test_HSet(t *testing.T) {
gtest.Assert(r.Strings(), g.ArrayStr{"name", "john"})
})
}
func Test_HGetAll(t *testing.T) {
gtest.Case(t, func() {
var err error
redis := gredis.New(config)
key := guuid.New()
defer redis.Do("DEL", key)
_, err = redis.Do("HSET", key, "id", "100")
gtest.Assert(err, nil)
_, err = redis.Do("HSET", key, "name", "john")
gtest.Assert(err, nil)
r, err := redis.DoVar("HGETALL", key)
gtest.Assert(err, nil)
gtest.Assert(r.Map(), g.MapStrAny{
"id": 100,
"name": "john",
})
})
}

View File

@ -10,6 +10,9 @@ package gdebug
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
@ -19,7 +22,6 @@ import (
"github.com/gogf/gf/encoding/ghash"
"github.com/gogf/gf/crypto/gmd5"
"github.com/gogf/gf/os/gfile"
)
const (
@ -31,20 +33,30 @@ var (
goRootForFilter = runtime.GOROOT() // goRootForFilter is used for stack filtering purpose.
binaryVersion = "" // The version of current running binary(uint64 hex).
binaryVersionMd5 = "" // The version of current running binary(MD5).
selfPath = "" // Current running binary absolute path.
)
func init() {
if goRootForFilter != "" {
goRootForFilter = strings.Replace(goRootForFilter, "\\", "/", -1)
}
// Initialize internal package variable: selfPath.
selfPath, _ := exec.LookPath(os.Args[0])
if selfPath != "" {
selfPath, _ = filepath.Abs(selfPath)
}
if selfPath == "" {
selfPath, _ = filepath.Abs(os.Args[0])
}
}
// BinVersion returns the version of current running binary.
// It uses ghash.BKDRHash+BASE36 algorithm to calculate the unique version of the binary.
func BinVersion() string {
if binaryVersion == "" {
binaryContent, _ := ioutil.ReadFile(selfPath)
binaryVersion = strconv.FormatInt(
int64(ghash.BKDRHash(gfile.GetBytes(gfile.SelfPath()))),
int64(ghash.BKDRHash(binaryContent)),
36,
)
}
@ -55,7 +67,7 @@ func BinVersion() string {
// It uses MD5 algorithm to calculate the unique version of the binary.
func BinVersionMd5() string {
if binaryVersionMd5 == "" {
binaryVersionMd5, _ = gmd5.EncryptFile(gfile.SelfPath())
binaryVersionMd5, _ = gmd5.EncryptFile(selfPath)
}
return binaryVersionMd5
}

View File

@ -0,0 +1,39 @@
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gcompress_test
import (
"testing"
"github.com/gogf/gf/encoding/gcompress"
"github.com/gogf/gf/test/gtest"
)
func Test_Gzip_UnGzip(t *testing.T) {
src := "Hello World!!"
gzip := []byte{
0x1f, 0x8b, 0x08, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xff,
0xf2, 0x48, 0xcd, 0xc9, 0xc9,
0x57, 0x08, 0xcf, 0x2f, 0xca,
0x49, 0x51, 0x54, 0x04, 0x04,
0x00, 0x00, 0xff, 0xff, 0x9d,
0x24, 0xa8, 0xd1, 0x0d, 0x00,
0x00, 0x00,
}
arr := []byte(src)
data, _ := gcompress.Gzip(arr)
gtest.Assert(data, gzip)
data, _ = gcompress.UnGzip(gzip)
gtest.Assert(data, arr)
data, _ = gcompress.UnGzip(gzip[1:])
gtest.Assert(data, nil)
}

View File

@ -0,0 +1,153 @@
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gcompress_test
import (
"bytes"
"github.com/gogf/gf/debug/gdebug"
"github.com/gogf/gf/encoding/gcompress"
"github.com/gogf/gf/os/gfile"
"github.com/gogf/gf/os/gtime"
"testing"
"github.com/gogf/gf/test/gtest"
)
func Test_ZipPath(t *testing.T) {
// file
gtest.Case(t, func() {
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1", "1.txt")
dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip")
gtest.Assert(gfile.Exists(dstPath), false)
err := gcompress.ZipPath(srcPath, dstPath)
gtest.Assert(err, nil)
gtest.Assert(gfile.Exists(dstPath), true)
defer gfile.Remove(dstPath)
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
err = gfile.Mkdir(tempDirPath)
gtest.Assert(err, nil)
err = gcompress.UnZipFile(dstPath, tempDirPath)
gtest.Assert(err, nil)
defer gfile.Remove(tempDirPath)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "1.txt")),
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
)
})
// directory
gtest.Case(t, func() {
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip")
dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip")
pwd := gfile.Pwd()
err := gfile.Chdir(srcPath)
defer gfile.Chdir(pwd)
gtest.Assert(err, nil)
gtest.Assert(gfile.Exists(dstPath), false)
err = gcompress.ZipPath(srcPath, dstPath)
gtest.Assert(err, nil)
gtest.Assert(gfile.Exists(dstPath), true)
defer gfile.Remove(dstPath)
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
err = gfile.Mkdir(tempDirPath)
gtest.Assert(err, nil)
err = gcompress.UnZipFile(dstPath, tempDirPath)
gtest.Assert(err, nil)
defer gfile.Remove(tempDirPath)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "zip", "path1", "1.txt")),
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "zip", "path2", "2.txt")),
gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")),
)
})
// multiple paths joined using char ','
gtest.Case(t, func() {
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip")
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1")
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path2")
dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip")
pwd := gfile.Pwd()
err := gfile.Chdir(srcPath)
defer gfile.Chdir(pwd)
gtest.Assert(err, nil)
gtest.Assert(gfile.Exists(dstPath), false)
err = gcompress.ZipPath(srcPath1+", "+srcPath2, dstPath)
gtest.Assert(err, nil)
gtest.Assert(gfile.Exists(dstPath), true)
defer gfile.Remove(dstPath)
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
err = gfile.Mkdir(tempDirPath)
gtest.Assert(err, nil)
zipContent := gfile.GetBytes(dstPath)
gtest.AssertGT(len(zipContent), 0)
err = gcompress.UnZipContent(zipContent, tempDirPath)
gtest.Assert(err, nil)
defer gfile.Remove(tempDirPath)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "path1", "1.txt")),
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "path2", "2.txt")),
gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")),
)
})
}
func Test_ZipPathWriter(t *testing.T) {
gtest.Case(t, func() {
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip")
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1")
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path2")
pwd := gfile.Pwd()
err := gfile.Chdir(srcPath)
defer gfile.Chdir(pwd)
gtest.Assert(err, nil)
writer := bytes.NewBuffer(nil)
gtest.Assert(writer.Len(), 0)
err = gcompress.ZipPathWriter(srcPath1+", "+srcPath2, writer)
gtest.Assert(err, nil)
gtest.AssertGT(writer.Len(), 0)
tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
err = gfile.Mkdir(tempDirPath)
gtest.Assert(err, nil)
zipContent := writer.Bytes()
gtest.AssertGT(len(zipContent), 0)
err = gcompress.UnZipContent(zipContent, tempDirPath)
gtest.Assert(err, nil)
defer gfile.Remove(tempDirPath)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "path1", "1.txt")),
gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")),
)
gtest.Assert(
gfile.GetContents(gfile.Join(tempDirPath, "path2", "2.txt")),
gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")),
)
})
}

View File

@ -13,7 +13,7 @@ import (
"github.com/gogf/gf/test/gtest"
)
func TestZlib(t *testing.T) {
func Test_Zlib_UnZlib(t *testing.T) {
gtest.Case(t, func() {
src := "hello, world\n"
dst := []byte{120, 156, 202, 72, 205, 201, 201, 215, 81, 40, 207, 47, 202, 73, 225, 2, 4, 0, 0, 255, 255, 33, 231, 4, 147}
@ -31,30 +31,4 @@ func TestZlib(t *testing.T) {
data, _ = gcompress.UnZlib(dst[1:])
gtest.Assert(data, nil)
})
}
func TestGzip(t *testing.T) {
src := "Hello World!!"
gzip := []byte{
0x1f, 0x8b, 0x08, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xff,
0xf2, 0x48, 0xcd, 0xc9, 0xc9,
0x57, 0x08, 0xcf, 0x2f, 0xca,
0x49, 0x51, 0x54, 0x04, 0x04,
0x00, 0x00, 0xff, 0xff, 0x9d,
0x24, 0xa8, 0xd1, 0x0d, 0x00,
0x00, 0x00,
}
arr := []byte(src)
data, _ := gcompress.Gzip(arr)
gtest.Assert(data, gzip)
data, _ = gcompress.UnGzip(gzip)
gtest.Assert(data, arr)
data, _ = gcompress.UnGzip(gzip[1:])
gtest.Assert(data, nil)
}

View File

@ -9,6 +9,7 @@ package gcompress
import (
"archive/zip"
"bytes"
"github.com/gogf/gf/internal/intlog"
"io"
"os"
"path/filepath"
@ -32,7 +33,15 @@ func ZipPath(paths, dest string, prefix ...string) error {
return err
}
defer writer.Close()
return ZipPathWriter(paths, writer, prefix...)
zipWriter := zip.NewWriter(writer)
defer zipWriter.Close()
for _, path := range strings.Split(paths, ",") {
path = strings.TrimSpace(path)
if err := doZipPathWriter(path, gfile.RealPath(dest), zipWriter, prefix...); err != nil {
return err
}
}
return nil
}
// ZipPathWriter compresses <paths> to <writer> using zip compressing algorithm.
@ -45,17 +54,21 @@ func ZipPathWriter(paths string, writer io.Writer, prefix ...string) error {
defer zipWriter.Close()
for _, path := range strings.Split(paths, ",") {
path = strings.TrimSpace(path)
if err := doZipPathWriter(path, zipWriter, prefix...); err != nil {
if err := doZipPathWriter(path, "", zipWriter, prefix...); err != nil {
return err
}
}
return nil
}
func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error {
// doZipPathWriter compresses the file of given <path> and writes the content to <zipWriter>.
// The parameter <exclude> specifies the exclusive file path that is not compressed to <zipWriter>,
// commonly the destination zip file path.
// The unnecessary parameter <prefix> indicates the path prefix for zip file.
func doZipPathWriter(path string, exclude string, zipWriter *zip.Writer, prefix ...string) error {
var err error
var files []string
realPath, err := gfile.Search(path)
path, err = gfile.Search(path)
if err != nil {
return err
}
@ -80,7 +93,11 @@ func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error
}
headerPrefix = strings.Replace(headerPrefix, "//", "/", -1)
for _, file := range files {
err := zipFile(file, headerPrefix+gfile.Dir(file[len(realPath):]), zipWriter)
if exclude == file {
intlog.Printf(`exclude file path: %s`, file)
continue
}
err := zipFile(file, headerPrefix+gfile.Dir(file[len(path):]), zipWriter)
if err != nil {
return err
}
@ -101,10 +118,10 @@ func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error
}
// UnZipFile decompresses <archive> to <dest> using zip compressing algorithm.
// The parameter <path> specifies the unzipped path of <archive>,
// The optional parameter <path> specifies the unzipped path of <archive>,
// which can be used to specify part of the archive file to unzip.
//
// Note thate the parameter <dest> should be a directory.
// Note that the parameter <dest> should be a directory.
func UnZipFile(archive, dest string, path ...string) error {
readerCloser, err := zip.OpenReader(archive)
if err != nil {
@ -118,7 +135,7 @@ func UnZipFile(archive, dest string, path ...string) error {
// The parameter <path> specifies the unzipped path of <archive>,
// which can be used to specify part of the archive file to unzip.
//
// Note thate the parameter <dest> should be a directory.
// Note that the parameter <dest> should be a directory.
func UnZipContent(data []byte, dest string, path ...string) error {
reader, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
if err != nil {
@ -178,6 +195,8 @@ func unZipFileWithReader(reader *zip.Reader, dest string, path ...string) error
return nil
}
// zipFile compresses the file of given <path> and writes the content to <zw>.
// The parameter <prefix> indicates the path prefix for zip file.
func zipFile(path string, prefix string, zw *zip.Writer) error {
file, err := os.Open(path)
if err != nil {

View File

@ -0,0 +1 @@
This is a test file for zip compression purpose.

View File

@ -0,0 +1 @@
This is an another test file for zip compression purpose.

View File

@ -136,7 +136,7 @@ func Test_GetMap(t *testing.T) {
gtest.Assert(err, nil)
gtest.Assert(j.GetMap("n"), nil)
gtest.Assert(j.GetMap("m"), g.Map{"k": "v"})
gtest.Assert(j.GetMap("a"), nil)
gtest.Assert(j.GetMap("a"), g.Map{"1": "2", "3": nil})
})
}

View File

@ -95,7 +95,7 @@ func Test_GetVar(t *testing.T) {
gtest.Assert(j.GetVar("m").Map(), g.Map{"k": "v"})
gtest.Assert(j.GetVar("a").Interfaces(), g.Slice{1, 2, 3})
gtest.Assert(j.GetVar("a").Slice(), g.Slice{1, 2, 3})
gtest.Assert(j.GetVar("a").Array(), g.Slice{1, 2, 3})
gtest.Assert(j.GetMap("a"), g.Map{"1": "2", "3": nil})
})
}
@ -106,7 +106,7 @@ func Test_GetMap(t *testing.T) {
gtest.AssertNE(j, nil)
gtest.Assert(j.GetMap("n"), nil)
gtest.Assert(j.GetMap("m"), g.Map{"k": "v"})
gtest.Assert(j.GetMap("a"), nil)
gtest.Assert(j.GetMap("a"), g.Map{"1": "2", "3": nil})
})
}

3
go.mod
View File

@ -7,8 +7,8 @@ require (
github.com/clbanning/mxj v1.8.4
github.com/fatih/structs v1.1.0
github.com/fsnotify/fsnotify v1.4.7
github.com/gf-third/mysql v1.4.2
github.com/gf-third/yaml v1.0.1
github.com/go-sql-driver/mysql v1.5.0
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/uuid v1.1.1
github.com/gorilla/websocket v1.4.1
@ -17,5 +17,4 @@ require (
github.com/olekukonko/tablewriter v0.0.1
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e // indirect
golang.org/x/text v0.3.2
google.golang.org/appengine v1.6.5 // indirect
)

View File

@ -4,39 +4,38 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package rwmutex provides switch of concurrent safe feature for sync.RWMutex.
// Package rwmutex provides switch of concurrent safety feature for sync.RWMutex.
package rwmutex
import "sync"
// RWMutex is a sync.RWMutex with a switch of concurrent safe feature.
// If its attribute *sync.RWMutex is not nil, it means it's in concurrent safety usage.
// Its attribute *sync.RWMutex is nil in default, which makes this struct mush lightweight.
type RWMutex struct {
sync.RWMutex
safe bool
*sync.RWMutex
}
// New creates and returns a new *RWMutex.
// The parameter <safe> is used to specify whether using this mutex in concurrent-safety,
// The parameter <safe> is used to specify whether using this mutex in concurrent safety,
// which is false in default.
func New(safe ...bool) *RWMutex {
mu := new(RWMutex)
if len(safe) > 0 {
mu.safe = safe[0]
} else {
mu.safe = false
if len(safe) > 0 && safe[0] {
mu.RWMutex = new(sync.RWMutex)
}
return mu
}
// IsSafe checks and returns whether current mutex is in concurrent-safe usage.
func (mu *RWMutex) IsSafe() bool {
return mu.safe
return mu.RWMutex != nil
}
// Lock locks mutex for writing.
// It does nothing if it is not in concurrent-safe usage.
func (mu *RWMutex) Lock() {
if mu.safe {
if mu.RWMutex != nil {
mu.RWMutex.Lock()
}
}
@ -44,7 +43,7 @@ func (mu *RWMutex) Lock() {
// Unlock unlocks mutex for writing.
// It does nothing if it is not in concurrent-safe usage.
func (mu *RWMutex) Unlock() {
if mu.safe {
if mu.RWMutex != nil {
mu.RWMutex.Unlock()
}
}
@ -52,7 +51,7 @@ func (mu *RWMutex) Unlock() {
// RLock locks mutex for reading.
// It does nothing if it is not in concurrent-safe usage.
func (mu *RWMutex) RLock() {
if mu.safe {
if mu.RWMutex != nil {
mu.RWMutex.RLock()
}
}
@ -60,7 +59,7 @@ func (mu *RWMutex) RLock() {
// RUnlock unlocks mutex for reading.
// It does nothing if it is not in concurrent-safe usage.
func (mu *RWMutex) RUnlock() {
if mu.safe {
if mu.RWMutex != nil {
mu.RWMutex.RUnlock()
}
}

View File

@ -50,9 +50,7 @@ func niceCallFunc(f func()) {
defer func() {
if err := recover(); err != nil {
switch err {
case gEXCEPTION_EXIT:
fallthrough
case gEXCEPTION_EXIT_ALL:
case gEXCEPTION_EXIT, gEXCEPTION_EXIT_ALL:
return
default:
panic(err)

View File

@ -8,6 +8,7 @@ package ghttp
import (
"fmt"
"github.com/gogf/gf/container/gmap"
"github.com/gogf/gf/os/gres"
"github.com/gogf/gf/os/gview"
"net/http"
@ -31,6 +32,7 @@ type Request struct {
LeaveTime int64 // Request ending time in microseconds.
Middleware *Middleware // The middleware manager.
StaticFile *StaticFile // Static file object when static file serving.
Context *gmap.StrAnyMap // Custom context map for internal usage purpose.
handlers []*handlerParsedItem // All matched handlers containing handler, hook and middleware for this request .
hasHookHandler bool // A bool marking whether there's hook handler in the handlers for performance purpose.
hasServeHandler bool // A bool marking whether there's serving handler in the handlers for performance purpose.
@ -66,6 +68,7 @@ func newRequest(s *Server, r *http.Request, w http.ResponseWriter) *Request {
Request: r,
Response: newResponse(s, w),
EnterTime: gtime.TimestampMilli(),
Context: gmap.NewStrAnyMap(),
}
request.Cookie = GetCookie(request)
request.Session = s.sessionManager.New(request.GetSessionId())

View File

@ -123,6 +123,7 @@ func (m *Middleware) Next() {
}, func(exception interface{}) {
m.request.error = gerror.Newf("%v", exception)
m.request.Response.WriteStatus(http.StatusInternalServerError, exception)
loop = false
})
}
// Check the http status code after all handler and middleware done.

View File

@ -9,6 +9,7 @@ package ghttp
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gogf/gf/container/gvar"
"github.com/gogf/gf/encoding/gjson"
"github.com/gogf/gf/encoding/gurl"
@ -302,5 +303,19 @@ func (r *Request) GetMultipartFiles(name string) []*multipart.FileHeader {
if v := form.File[name+"[]"]; len(v) > 0 {
return v
}
// Support "name[0]","name[1]","name[2]", etc. as array parameter.
key := ""
files := make([]*multipart.FileHeader, 0)
for i := 0; ; i++ {
key = fmt.Sprintf(`%s[%d]`, name, i)
if v := form.File[key]; len(v) > 0 {
files = append(files, v[0])
} else {
break
}
}
if len(files) > 0 {
return files
}
return nil
}

View File

@ -8,12 +8,12 @@ package ghttp
import (
"errors"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/os/gfile"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/util/grand"
"io"
"mime/multipart"
"os"
"strconv"
"strings"
)
@ -26,66 +26,67 @@ type UploadFile struct {
// UploadFiles is array type for *UploadFile.
type UploadFiles []*UploadFile
// Save saves the single uploading file to specified path.
// The parameter path can be either a directory or a file path. If <path> is a directory,
// it saves the uploading file to the directory using its original name. If <path> is a
// file path, it saves the uploading file to the file path.
// Save saves the single uploading file to directory path and returns the saved file name.
//
// The parameter <dirPath> should be a directory path or it returns error.
//
// The parameter <randomlyRename> specifies whether randomly renames the file name, which
// make sense if the <path> is a directory.
//
// Note that it will overwrite the target file if there's already a same name file exist.
func (f *UploadFile) Save(path string, randomlyRename ...bool) error {
func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename string, err error) {
if f == nil {
return nil
return
}
if !gfile.Exists(dirPath) {
if err = gfile.Mkdir(dirPath); err != nil {
return
}
} else if !gfile.IsDir(dirPath) {
return "", errors.New(`parameter "dirPath" should be a directory path`)
}
file, err := f.Open()
if err != nil {
return err
return "", err
}
defer file.Close()
var newFile *os.File
if gfile.IsDir(path) {
filename := gfile.Basename(f.Filename)
if len(randomlyRename) > 0 && randomlyRename[0] {
filename = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6))
filename = filename + gfile.Ext(f.Filename)
}
newFile, err = gfile.Create(gfile.Join(path, filename))
} else {
newFile, err = gfile.Create(path)
name := gfile.Basename(f.Filename)
if len(randomlyRename) > 0 && randomlyRename[0] {
name = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6))
name = name + gfile.Ext(f.Filename)
}
filePath := gfile.Join(dirPath, name)
newFile, err := gfile.Create(filePath)
if err != nil {
return err
return "", err
}
defer newFile.Close()
intlog.Printf(`save upload file: %s`, filePath)
if _, err := io.Copy(newFile, file); err != nil {
return err
return "", err
}
return nil
return gfile.Basename(filePath), nil
}
// Save saves all uploading files to specified directory path.
// Save saves all uploading files to specified directory path and returns the saved file names.
//
// The parameter <dirPath> should be a directory path or it returns error.
//
// The parameter <randomlyRename> specifies whether randomly renames all the file names.
func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) error {
func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) (filenames []string, err error) {
if len(fs) == 0 {
return nil
return nil, nil
}
if !gfile.IsDir(dirPath) {
return errors.New(`parameter "dirPath" should be a directory path`)
}
var err error
for _, f := range fs {
if err = f.Save(dirPath, randomlyRename...); err != nil {
return err
if filename, err := f.Save(dirPath, randomlyRename...); err != nil {
return filenames, err
} else {
filenames = append(filenames, filename)
}
}
return nil
return
}
// GetUploadFile retrieves and returns the uploading file with specified form name.

View File

@ -0,0 +1,64 @@
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp
import (
"fmt"
"github.com/gogf/gf/text/gregex"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gpage"
)
// GetPage creates and returns the pagination object for given <totalSize> and <pageSize>.
// NOTE THAT the page parameter name from client is constantly defined as gpage.PAGE_NAME
// for simplification and convenience.
func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page {
// It must has Router object attribute.
if r.Router == nil {
panic("Router object not found")
}
url := *r.URL
urlTemplate := url.Path
uriHasPageName := false
// Check the page variable in the URI.
if len(r.Router.RegNames) > 0 {
for _, name := range r.Router.RegNames {
if name == gpage.PAGE_NAME {
uriHasPageName = true
break
}
}
if uriHasPageName {
if match, err := gregex.MatchString(r.Router.RegRule, url.Path); err == nil && len(match) > 0 {
if len(match) > len(r.Router.RegNames) {
urlTemplate = r.Router.Uri
for i, name := range r.Router.RegNames {
rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name)
if name == gpage.PAGE_NAME {
urlTemplate, _ = gregex.ReplaceString(rule, gpage.PAGE_PLACE_HOLDER, urlTemplate)
} else {
urlTemplate, _ = gregex.ReplaceString(rule, match[i+1], urlTemplate)
}
}
}
}
}
}
// Check the page variable in the query string.
if !uriHasPageName {
values := url.Query()
values.Set(gpage.PAGE_NAME, gpage.PAGE_PLACE_HOLDER)
url.RawQuery = values.Encode()
// Replace the encodes "{.page}" to "{.page}".
url.RawQuery = gstr.Replace(url.RawQuery, "%7B.page%7D", "{.page}")
}
if url.RawQuery != "" {
urlTemplate += "?" + url.RawQuery
}
return gpage.New(totalSize, pageSize, r.GetInt(gpage.PAGE_NAME), urlTemplate)
}

View File

@ -8,11 +8,10 @@
package ghttp
import (
"net/http"
"net/url"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"net/http"
"net/url"
)
// CORSOptions is the options for CORS feature.
@ -27,6 +26,20 @@ type CORSOptions struct {
AllowHeaders string // Access-Control-Allow-Headers
}
var (
// defaultAllowHeaders is the default allowed headers for CORS.
// It's defined another map for better header key searching performance.
defaultAllowHeaders = "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With"
defaultAllowHeadersMap = make(map[string]struct{})
)
func init() {
array := gstr.SplitAndTrim(defaultAllowHeaders, ",")
for _, header := range array {
defaultAllowHeadersMap[header] = struct{}{}
}
}
// DefaultCORSOptions returns the default CORS options,
// which allows any cross-domain request.
func (r *Response) DefaultCORSOptions() CORSOptions {
@ -34,9 +47,19 @@ func (r *Response) DefaultCORSOptions() CORSOptions {
AllowOrigin: "*",
AllowMethods: HTTP_METHODS,
AllowCredentials: "true",
AllowHeaders: "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With",
AllowHeaders: defaultAllowHeaders,
MaxAge: 3628800,
}
// Allow all client's custom headers in default.
if headers := r.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
array := gstr.SplitAndTrim(headers, ",")
for _, header := range array {
if _, ok := defaultAllowHeadersMap[header]; !ok {
options.AllowHeaders += header + ","
}
}
}
// Allow all anywhere origin in default.
if origin := r.Request.Header.Get("Origin"); origin != "" {
options.AllowOrigin = origin
} else if referer := r.Request.Referer(); referer != "" {
@ -72,8 +95,26 @@ func (r *Response) CORS(options CORSOptions) {
}
// No continue service handling if it's OPTIONS request.
if gstr.Equal(r.Request.Method, "OPTIONS") {
// Request method handler searching.
// It here simply uses Server.routesMap attribute enhancing the searching performance.
if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" {
routerKey := ""
for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} {
for _, v := range []string{gDEFAULT_METHOD, method} {
routerKey = r.Server.routerMapKey("", v, r.Request.URL.Path, domain)
if r.Server.routesMap[routerKey] != nil {
if r.Status == 0 {
r.Status = http.StatusOK
}
// No continue serving.
r.Request.ExitAll()
}
}
}
}
// Cannot find the request serving handler, it then responses 404.
if r.Status == 0 {
r.Status = http.StatusOK
r.Status = http.StatusNotFound
}
r.Request.ExitAll()
}

View File

@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"github.com/gogf/gf/debug/gdebug"
"github.com/gogf/gf/internal/intlog"
"net/http"
"os"
"reflect"
@ -78,7 +79,7 @@ type (
// handlerItem is the registered handler for route handling,
// including middleware and hook functions.
handlerItem struct {
itemId int // Unique ID mark.
itemId int // Unique handler item id mark.
itemName string // Handler name, which is automatically retrieved from runtime stack when registered.
itemType int // Handler type: object/handler/controller/middleware/hook.
itemFunc HandlerFunc // Handler address.
@ -143,7 +144,7 @@ var (
// it is used for quick HTTP method searching using map.
methodsMap = make(map[string]struct{})
// serverMapping stores more than one server instances.
// serverMapping stores more than one server instances for current process.
// The key is the name of the server, and the value is its instance.
serverMapping = gmap.NewStrAnyMap(true)
@ -444,14 +445,20 @@ func (s *Server) GetRouterArray() []RouterItem {
}
// Run starts server listening in blocking way.
// It's commonly used for single server situation.
func (s *Server) Run() {
if err := s.Start(); err != nil {
s.Logger().Fatal(err)
}
// Blocking using channel.
<-s.closeChan
// Remove plugins.
if len(s.plugins) > 0 {
for _, p := range s.plugins {
intlog.Printf(`remove plugin: %s`, p.Name())
p.Remove()
}
}
s.Logger().Printf("[ghttp] %d: all servers shutdown", gproc.Pid())
}
@ -459,7 +466,17 @@ func (s *Server) Run() {
// It's commonly used in multiple servers situation.
func Wait() {
<-allDoneChan
// Remove plugins.
serverMapping.Iterator(func(k string, v interface{}) bool {
s := v.(*Server)
if len(s.plugins) > 0 {
for _, p := range s.plugins {
intlog.Printf(`remove plugin: %s`, p.Name())
p.Remove()
}
}
return true
})
glog.Printf("[ghttp] %d: all servers shutdown", gproc.Pid())
}

View File

@ -10,7 +10,7 @@ package ghttp
type Plugin interface {
Name() string // Name returns the name of the plugin.
Author() string // Author returns the author of the plugin.
Version() string // Version returns the version of the plugin.
Version() string // Version returns the version of the plugin, like "v1.0.0".
Description() string // Description returns the description of the plugin.
Install(s *Server) error // Install installs the plugin before server starts.
Remove() error // Remove removes the plugin.

View File

@ -24,11 +24,18 @@ const (
)
var (
// 用于服务函数的ID生成变量
// handlerIdGenerator is handler item id generator.
handlerIdGenerator = gtype.NewInt()
)
// 解析pattern
// routerMapKey creates and returns an unique router key for given parameters.
// This key is used for Server.routerMap attribute, which is mainly for checks for
// repeated router registering.
func (s *Server) routerMapKey(hook, method, path, domain string) string {
return hook + "%" + s.serveHandlerKey(method, path, domain)
}
// parsePattern parses the given pattern to domain, method and path variable.
func (s *Server) parsePattern(pattern string) (domain, method, path string, err error) {
path = strings.TrimSpace(pattern)
domain = gDEFAULT_DOMAIN
@ -48,16 +55,16 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err
if path == "" {
err = errors.New("invalid pattern: URI should not be empty")
}
// 去掉末尾的"/"符号,与路由匹配时处理一致
if path != "/" {
path = strings.TrimRight(path, "/")
}
return
}
// 路由注册处理方法。
// 非叶节点为哈希表检索节点按照URI注册的层级进行高效检索直至到叶子链表节点
// 叶子节点是链表,按照优先级进行排序,优先级高的排前面,按照遍历检索,按照哈希表层级检索后的叶子链表数据量不会很大,所以效率比较高;
// setHandler creates router item with given handler and pattern and registers the handler to the router tree.
// The router tree can be treated as a multilayer hash table, please refer to the comment in following codes.
// This function is called during server starts up, which cares little about the performance. What really cares
// is the well designed router storage structure for router searching when the request is under serving.
func (s *Server) setHandler(pattern string, handler *handlerItem) {
handler.itemId = handlerIdGenerator.Add(1)
domain, method, uri, err := s.parsePattern(pattern)
@ -69,30 +76,32 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
s.Logger().Fatal("invalid pattern:", pattern, "URI should lead with '/'")
return
}
// 注册地址记录及重复注册判断
regKey := s.handlerKey(handler.hookName, method, uri, domain)
// Repeated router checks, this feature can be disabled by server configuration.
routerKey := s.routerMapKey(handler.hookName, method, uri, domain)
if !s.config.RouteOverWrite {
switch handler.itemType {
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
if item, ok := s.routesMap[regKey]; ok {
if item, ok := s.routesMap[routerKey]; ok {
s.Logger().Fatalf(`duplicated route registry "%s", already registered at %s`, pattern, item[0].file)
return
}
}
}
// 注册的路由信息对象
// Create a new router by given parameter.
handler.router = &Router{
Uri: uri,
Domain: domain,
Method: method,
Method: strings.ToUpper(method),
Priority: strings.Count(uri[1:], "/"),
}
handler.router.RegRule, handler.router.RegNames = s.patternToRegRule(uri)
handler.router.RegRule, handler.router.RegNames = s.patternToRegular(uri)
if _, ok := s.serveTree[domain]; !ok {
s.serveTree[domain] = make(map[string]interface{})
}
// 当前节点的规则链表
// List array, very important for router registering.
// There may be multiple lists adding into this array when searching from root to leaf.
lists := make([]*glist.List, 0)
array := ([]string)(nil)
if strings.EqualFold("/", uri) {
@ -100,43 +109,58 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
} else {
array = strings.Split(uri[1:], "/")
}
// 键名"*fuzz"代表当前节点为模糊匹配节点,该节点也会有一个*list链表
// 键名"*list"代表链表,叶子节点和模糊匹配节点都有该属性,优先级越高越排前;
// Multilayer hash table:
// 1. Each node of the table is separated by URI path which is split by char '/'.
// 2. The key "*fuzz" specifies this node is a fuzzy node, which has no certain name.
// 3. The key "*list" is the list item of the node, MOST OF THE NODES HAVE THIS ITEM,
// especially the fuzzy node. NOTE THAT the fuzzy node must have the "*list" item,
// and the leaf node also has "*list" item. If the node is not a fuzzy node either
// a leaf, it neither has "*list" item.
// 2. The "*list" item is a list containing registered router items ordered by their
// priorities from high to low.
// 3. There may be repeated router items in the router lists. The lists' priorities
// from root to leaf are from low to high.
p := s.serveTree[domain]
for k, v := range array {
if len(v) == 0 {
for i, part := range array {
// Ignore empty URI part, like: /user//index
if part == "" {
continue
}
// 判断是否模糊匹配规则
if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, v) {
v = "*fuzz"
// 由于是模糊规则,因此这里会有一个*list用以将后续的路由规则加进来
// 检索会从叶子节点的链表往根节点按照优先级进行检索
// Check if it's a fuzzy node.
if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, part) {
part = "*fuzz"
// If it's a fuzzy node, it creates a "*list" item - which is a list - in the hash map.
// All the sub router items from this fuzzy node will also be added to its "*list" item.
if v, ok := p.(map[string]interface{})["*list"]; !ok {
p.(map[string]interface{})["*list"] = glist.New()
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
newListForFuzzy := glist.New()
p.(map[string]interface{})["*list"] = newListForFuzzy
lists = append(lists, newListForFuzzy)
} else {
lists = append(lists, v.(*glist.List))
}
}
// 属性层级数据写入
if _, ok := p.(map[string]interface{})[v]; !ok {
p.(map[string]interface{})[v] = make(map[string]interface{})
// Make a new bucket for current node.
if _, ok := p.(map[string]interface{})[part]; !ok {
p.(map[string]interface{})[part] = make(map[string]interface{})
}
p = p.(map[string]interface{})[v]
// 到达叶子节点往list中增加匹配规则(条件 v != "*fuzz" 是因为模糊节点的话在前面已经添加了*list链表)
if k == len(array)-1 && v != "*fuzz" {
// Loop to next bucket.
p = p.(map[string]interface{})[part]
// The leaf is a hash map and must have an item named "*list", which contains the router item.
// The leaf can be furthermore extended by adding more ket-value pairs into its map.
// Note that the `v != "*fuzz"` comparison is required as the list might be added in the former
// fuzzy checks.
if i == len(array)-1 && part != "*fuzz" {
if v, ok := p.(map[string]interface{})["*list"]; !ok {
p.(map[string]interface{})["*list"] = glist.New()
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
leafList := glist.New()
p.(map[string]interface{})["*list"] = leafList
lists = append(lists, leafList)
} else {
lists = append(lists, v.(*glist.List))
}
}
}
// 上面循环后得到的lists是该路由规则一路匹配下来相关的模糊匹配链表(注意不是这棵树所有的链表)。
// 下面从头开始遍历每个节点的模糊匹配链表,将该路由项插入进去(按照优先级高的放在lists链表的前面)
// It iterates the list array of <lists>, compares priorities and inserts the new router item in
// the proper position of each list. The priority of the list is ordered from high to low.
item := (*handlerItem)(nil)
for _, l := range lists {
pushed := false
@ -157,8 +181,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
}
}
// Initialize the route map item.
if _, ok := s.routesMap[regKey]; !ok {
s.routesMap[regKey] = make([]registeredRouteItem, 0)
if _, ok := s.routesMap[routerKey]; !ok {
s.routesMap[routerKey] = make([]registeredRouteItem, 0)
}
_, file, line := gdebug.CallerWithFilter(gFILTER_KEY)
routeItem := registeredRouteItem{
@ -168,35 +192,39 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
switch handler.itemType {
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
// Overwrite the route.
s.routesMap[regKey] = []registeredRouteItem{routeItem}
s.routesMap[routerKey] = []registeredRouteItem{routeItem}
default:
// Append the route.
s.routesMap[regKey] = append(s.routesMap[regKey], routeItem)
s.routesMap[routerKey] = append(s.routesMap[routerKey], routeItem)
}
}
// 对比两个handlerItem的优先级需要非常注意的是注意新老对比项的参数先后顺序。
// 返回值true表示newItem优先级比oldItem高会被添加链表中oldRouter的前面否则后面。
// 优先级比较规则:
// 1、中间件优先级最高按照添加顺序优先级执行
// 2、其他路由注册类型层级越深优先级越高(对比/数量)
// 3、模糊规则优先级{xxx} > :xxx > *xxx
// compareRouterPriority compares the priority between <newItem> and <oldItem>. It returns true
// if <newItem>'s priority is higher than <oldItem>, else it returns false. The higher priority
// item will be insert into the router list before the other one.
//
// Comparison rules:
// 1. The middleware has the most high priority.
// 2. URI: The deeper the higher (simply check the count of char '/' in the URI).
// 3. Route type: {xxx} > :xxx > *xxx.
func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerItem) bool {
// 中间件优先级最高,按照添加顺序优先级执行
// If they're all type of middleware, the priority is according their registered sequence.
if newItem.itemType == gHANDLER_TYPE_MIDDLEWARE && oldItem.itemType == gHANDLER_TYPE_MIDDLEWARE {
return false
}
// The middleware has the most high priority.
if newItem.itemType == gHANDLER_TYPE_MIDDLEWARE && oldItem.itemType != gHANDLER_TYPE_MIDDLEWARE {
return true
}
// 优先比较层级,层级越深优先级越高
// URI: The deeper the higher (simply check the count of char '/' in the URI).
if newItem.router.Priority > oldItem.router.Priority {
return true
}
if newItem.router.Priority < oldItem.router.Priority {
return false
}
// 精准匹配比模糊匹配规则优先级高,例如:/name/act 比 /{name}/:act 优先级高
// Route type: {xxx} > :xxx > *xxx.
// Eg: /name/act > /{name}/:act
var fuzzyCountFieldNew, fuzzyCountFieldOld int
var fuzzyCountNameNew, fuzzyCountNameOld int
var fuzzyCountAnyNew, fuzzyCountAnyOld int
@ -230,16 +258,16 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
return false
}
/** 如果模糊规则数量相等,那么执行分别的数量判断 **/
// If the counts of their fuzzy rules equal.
// 例如:/name/{act} 比 /name/:act 优先级高
// Eg: /name/{act} > /name/:act
if fuzzyCountFieldNew > fuzzyCountFieldOld {
return true
}
if fuzzyCountFieldNew < fuzzyCountFieldOld {
return false
}
// 例如: /name/:act 比 /name/*act 优先级高
// Eg: /name/:act > /name/*act
if fuzzyCountNameNew > fuzzyCountNameOld {
return true
}
@ -247,9 +275,10 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
return false
}
/** 比较路由规则长度,越长的规则优先级越高,模糊/命名规则不算长度 **/
// It then compares the length of their URI,
// but the fuzzy and named parts of the URI are not calculated to the result.
// 例如:/admin-goods-{page} 比 /admin-{page} 优先级高
// Eg: /admin-goods-{page} > /admin-{page}
var uriNew, uriOld string
uriNew, _ = gregex.ReplaceString(`\{[^/]+\}`, "", newItem.router.Uri)
uriNew, _ = gregex.ReplaceString(`:[^/]+`, "", uriNew)
@ -264,9 +293,8 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
return false
}
/* 模糊规则数量相等,后续不用再判断*规则的数量比较了 */
// 比较HTTP METHOD更精准的优先级更高
// It then compares the accuracy of their http method,
// the more accurate the more priority.
if newItem.router.Method != gDEFAULT_METHOD {
return true
}
@ -274,23 +302,25 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte
return true
}
// 如果是服务路由,那么新的规则比旧的规则优先级高(路由覆盖)
// If they have different router type,
// the new router item has more priority than the other one.
if newItem.itemType == gHANDLER_TYPE_HANDLER ||
newItem.itemType == gHANDLER_TYPE_OBJECT ||
newItem.itemType == gHANDLER_TYPE_CONTROLLER {
return true
}
// 如果是其他路由(HOOK/中间件),那么新的规则比旧的规则优先级低,使得注册相同路由则顺序执行
// Other situations, like HOOK items,
// the old router item has more priority than the other one.
return false
}
// 将pattern不带method和domain解析成正则表达式匹配以及对应的query字符串
func (s *Server) patternToRegRule(rule string) (regrule string, names []string) {
// patternToRegular converts route rule to according regular expression.
func (s *Server) patternToRegular(rule string) (regular string, names []string) {
if len(rule) < 2 {
return rule, nil
}
regrule = "^"
regular = "^"
array := strings.Split(rule[1:], "/")
for _, v := range array {
if len(v) == 0 {
@ -299,20 +329,20 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
switch v[0] {
case ':':
if len(v) > 1 {
regrule += `/([^/]+)`
regular += `/([^/]+)`
names = append(names, v[1:])
} else {
regrule += `/[^/]+`
regular += `/[^/]+`
}
case '*':
if len(v) > 1 {
regrule += `/{0,1}(.*)`
regular += `/{0,1}(.*)`
names = append(names, v[1:])
} else {
regrule += `/{0,1}.*`
regular += `/{0,1}.*`
}
default:
// 特殊字符替换
// Special chars replacement.
v = gstr.ReplaceByMap(v, map[string]string{
`.`: `\.`,
`+`: `\+`,
@ -323,12 +353,12 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
return `([^/]+)`
})
if strings.EqualFold(s, v) {
regrule += "/" + v
regular += "/" + v
} else {
regrule += "/" + s
regular += "/" + s
}
}
}
regrule += `$`
regular += `$`
return
}

View File

@ -66,8 +66,3 @@ func (s *Server) niceCallHookHandler(f HandlerFunc, r *Request) (err interface{}
f(r)
return
}
// 生成hook key如果是hook key那么使用'%'符号分隔
func (s *Server) handlerKey(hook, method, path, domain string) string {
return hook + "%" + s.serveHandlerKey(method, path, domain)
}

View File

@ -15,18 +15,37 @@ import (
"github.com/gogf/gf/text/gregex"
)
// 缓存数据项
// handlerCacheItem is an item just for internal router searching cache.
type handlerCacheItem struct {
parsedItems []*handlerParsedItem
hasHook bool
hasServe bool
}
// 查询请求处理方法.
// 内部带锁机制可以并发读但是不能并发写并且有缓存机制按照Host、Method、Path进行缓存.
// serveHandlerKey creates and returns a handler key for router.
func (s *Server) serveHandlerKey(method, path, domain string) string {
if len(domain) > 0 {
domain = "@" + domain
}
if method == "" {
return path + strings.ToLower(domain)
}
return strings.ToUpper(method) + ":" + path + strings.ToLower(domain)
}
// getHandlersWithCache searches the router item with cache feature for given request.
func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) {
value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} {
parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost())
method := r.Method
// Special http method OPTIONS handling.
// It searches the handler with the request method instead of OPTIONS method.
if method == "OPTIONS" {
if v := r.Request.Header.Get("Access-Control-Request-Method"); v != "" {
method = v
}
}
// Search and cache the router handlers.
value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(method, r.URL.Path, r.GetHost()), func() interface{} {
parsedItems, hasHook, hasServe = s.searchHandlers(method, r.URL.Path, r.GetHost())
if parsedItems != nil {
return &handlerCacheItem{parsedItems, hasHook, hasServe}
}
@ -39,18 +58,14 @@ func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedI
return
}
// 路由注册方法检索,返回所有该路由的注册函数,构造成数组返回
// searchHandlers retrieves and returns the routers with given parameters.
// Note that the returned routers contain serving handler, middleware handlers and hook handlers.
func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) {
if len(path) == 0 {
return nil, false, false
}
// 遍历检索的域名列表,优先遍历默认域名
domains := []string{gDEFAULT_DOMAIN}
if !strings.EqualFold(gDEFAULT_DOMAIN, domain) {
domains = append(domains, domain)
}
// URL.Path层级拆分
array := ([]string)(nil)
// Split the URL.path to separate parts.
var array []string
if strings.EqualFold("/", path) {
array = []string{"/"}
} else {
@ -58,85 +73,93 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
}
parsedItemList := glist.New()
lastMiddlewareElem := (*glist.Element)(nil)
repeatHandlerCheckMap := make(map[int]struct{})
for _, domain := range domains {
repeatHandlerCheckMap := make(map[int]struct{}, 16)
// Default domain has the most priority when iteration.
for _, domain := range []string{gDEFAULT_DOMAIN, domain} {
p, ok := s.serveTree[domain]
if !ok {
continue
}
// 多层链表(每个节点都有一个*list链表)的目的是当叶子节点未有任何规则匹配时,让父级模糊匹配规则继续处理
// Make a list array with capacity of 16.
lists := make([]*glist.List, 0, 16)
for k, v := range array {
for i, part := range array {
// In case of double '/' URI, eg: /user//index
if v == "" {
if part == "" {
continue
}
if _, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
// Add all list of each node to the list array.
if v, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, v.(*glist.List))
}
if _, ok := p.(map[string]interface{})[v]; ok {
p = p.(map[string]interface{})[v]
if k == len(array)-1 {
if _, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
if v, ok := p.(map[string]interface{})[part]; ok {
// Loop to the next node by certain key name.
p = v
if i == len(array)-1 {
if v, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, v.(*glist.List))
break
}
}
} else {
if _, ok := p.(map[string]interface{})["*fuzz"]; ok {
p = p.(map[string]interface{})["*fuzz"]
}
} else if v, ok := p.(map[string]interface{})["*fuzz"]; ok {
// Loop to the next node by fuzzy node item.
p = v
}
// 如果是叶子节点,同时判断当前层级的"*fuzz"键名,解决例如:/user/*action 匹配 /user 的规则
if k == len(array)-1 {
if _, ok := p.(map[string]interface{})["*fuzz"]; ok {
p = p.(map[string]interface{})["*fuzz"]
if i == len(array)-1 {
// It here also checks the fuzzy item,
// for rule case like: "/user/*action" matches to "/user".
if v, ok := p.(map[string]interface{})["*fuzz"]; ok {
p = v
}
if _, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
// The leaf must have a list item. It adds the list to the list array.
if v, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, v.(*glist.List))
}
}
}
// 多层链表遍历检索,从数组末尾的链表开始遍历,末尾的深度高优先级也高
// OK, let's loop the result list array, adding the handler item to the result handler result array.
// As the tail of the list array has the most priority, it iterates the list array from its tail to head.
for i := len(lists) - 1; i >= 0; i-- {
for e := lists[i].Front(); e != nil; e = e.Next() {
item := e.Value.(*handlerItem)
// 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数)
// Filter repeated handler item, especially the middleware and hook handlers.
// It is necessary, do not remove this checks logic unless you really know how it is necessary.
if _, ok := repeatHandlerCheckMap[item.itemId]; ok {
continue
} else {
repeatHandlerCheckMap[item.itemId] = struct{}{}
}
// 服务路由函数只能添加一次,将重复判断放在这里提高检索效率
// Serving handler can only be added to the handler array just once.
if hasServe {
switch item.itemType {
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
continue
}
}
// 动态匹配规则带有gDEFAULT_METHOD的情况不会像静态规则那样直接解析为所有的HTTP METHOD存储
if strings.EqualFold(item.router.Method, gDEFAULT_METHOD) || strings.EqualFold(item.router.Method, method) {
// 注意当不带任何动态路由规则时len(match) == 1
if item.router.Method == gDEFAULT_METHOD || item.router.Method == method {
// Note the rule having no fuzzy rules: len(match) == 1
if match, err := gregex.MatchString(item.router.RegRule, path); err == nil && len(match) > 0 {
parsedItem := &handlerParsedItem{item, nil}
// 如果需要路由规则中带有URI名称匹配那么需要重新正则解析URL
// If the rule contains fuzzy names,
// it needs paring the URL to retrieve the values for the names.
if len(item.router.RegNames) > 0 {
if len(match) > len(item.router.RegNames) {
parsedItem.values = make(map[string]string)
// 如果存在存在同名路由参数名称,那么执行覆盖
// It there repeated names, it just overwrites the same one.
for i, name := range item.router.RegNames {
parsedItem.values[name] = match[i+1]
}
}
}
switch item.itemType {
// 服务路由函数只能添加一次
// The serving handler can be only added just once.
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
hasServe = true
parsedItemList.PushBack(parsedItem)
// 中间件需要排序在链表中服务函数之前,并且多个中间件按照顺序添加以便于后续执行
// The middleware is inserted before the serving handler.
// If there're multiple middlewares, they're inserted into the result list by their registering order.
// The middlewares are also executed by their registering order.
case gHANDLER_TYPE_MIDDLEWARE:
if lastMiddlewareElem == nil {
lastMiddlewareElem = parsedItemList.PushFront(parsedItem)
@ -144,7 +167,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
lastMiddlewareElem = parsedItemList.InsertAfter(lastMiddlewareElem, parsedItem)
}
// 钩子函数存在性判断
// HOOK handler, just push it back to the list.
case gHANDLER_TYPE_HOOK:
hasHook = true
parsedItemList.PushBack(parsedItem)
@ -206,14 +229,3 @@ func (item *handlerItem) MarshalJSON() ([]byte, error) {
func (item *handlerParsedItem) MarshalJSON() ([]byte, error) {
return json.Marshal(item.handler)
}
// 生成回调方法查询的Key
func (s *Server) serveHandlerKey(method, path, domain string) string {
if len(domain) > 0 {
domain = "@" + domain
}
if method == "" {
return path + strings.ToLower(domain)
}
return strings.ToUpper(method) + ":" + path + strings.ToLower(domain)
}

View File

@ -39,7 +39,7 @@ func (s *Server) BindObjectRest(pattern string, object interface{}) {
}
func (s *Server) doBindObject(pattern string, object interface{}, method string, middleware []HandlerFunc) {
// Convert input method to map for convenience and high performance searching.
// Convert input method to map for convenience and high performance searching purpose.
var methodMap map[string]bool
if len(method) > 0 {
methodMap = make(map[string]bool)
@ -86,12 +86,16 @@ func (s *Server) doBindObject(pattern string, object interface{}, method string,
if !ok {
if len(methodMap) > 0 {
// 指定的方法名称注册,那么需要使用错误提示
s.Logger().Errorf(`invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`,
pkgPath, objName, methodName, v.Method(i).Type().String())
s.Logger().Errorf(
`invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`,
pkgPath, objName, methodName, v.Method(i).Type().String(),
)
} else {
// 否则只是Debug提示
s.Logger().Debugf(`ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`,
pkgPath, objName, methodName, v.Method(i).Type().String())
s.Logger().Debugf(
`ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`,
pkgPath, objName, methodName, v.Method(i).Type().String(),
)
}
continue
}

View File

@ -0,0 +1,43 @@
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp_test
import (
"fmt"
"testing"
"time"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/test/gtest"
)
func Test_Context(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Group("/", func(group *ghttp.RouterGroup) {
group.Middleware(func(r *ghttp.Request) {
r.Context.Set("traceid", 123)
r.Middleware.Next()
})
group.GET("/", func(r *ghttp.Request) {
r.Response.Write(r.Context.Get("traceid"))
})
})
s.SetPort(p)
//s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), `123`)
})
}

View File

@ -594,8 +594,9 @@ func MiddlewareCORS(r *ghttp.Request) {
func Test_Middleware_CORSAndAuth(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Use(MiddlewareCORS)
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
group.Middleware(MiddlewareAuth, MiddlewareCORS)
group.Middleware(MiddlewareAuth)
group.POST("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
@ -680,3 +681,35 @@ func Test_Middleware_Scope(t *testing.T) {
gtest.Assert(client.GetContent("/scope3"), "ae3fb")
})
}
func Test_Middleware_Panic(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
i := 0
s.Group("/", func(group *ghttp.RouterGroup) {
group.Group("/", func(group *ghttp.RouterGroup) {
group.Middleware(func(r *ghttp.Request) {
i++
panic("error")
r.Middleware.Next()
}, func(r *ghttp.Request) {
i++
r.Middleware.Next()
})
group.ALL("/", func(r *ghttp.Request) {
r.Response.Write(i)
})
})
})
s.SetPort(p)
//s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "error")
})
}

View File

@ -0,0 +1,76 @@
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp_test
import (
"fmt"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/test/gtest"
"testing"
"time"
)
func Test_Middleware_CORS(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
group.Middleware(MiddlewareCORS)
group.POST("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
// Common Checks.
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
// GET request does not any route.
resp, err := client.Get("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
resp.Close()
// POST request matches the route and CORS middleware.
resp, err = client.Post("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With")
gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE")
gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*")
gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800")
resp.Close()
})
// OPTIONS GET
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
client.SetHeader("Access-Control-Request-Method", "GET")
resp, err := client.Options("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
gtest.Assert(resp.ReadAllString(), "Not Found")
resp.Close()
})
// OPTIONS POST
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
client.SetHeader("Access-Control-Request-Method", "POST")
resp, err := client.Options("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
resp.Close()
})
}

View File

@ -0,0 +1,176 @@
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp_test
import (
"fmt"
"github.com/gogf/gf/debug/gdebug"
"github.com/gogf/gf/os/gfile"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/text/gstr"
"testing"
"time"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/test/gtest"
)
func Test_Params_File_Single(t *testing.T) {
dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
p := ports.PopRand()
s := g.Server(p)
s.BindHandler("/upload/single", func(r *ghttp.Request) {
file := r.GetUploadFile("file")
if file == nil {
r.Response.WriteExit("upload file cannot be empty")
}
if name, err := file.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil {
r.Response.WriteExit(name)
}
r.Response.WriteExit("upload failed")
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
// normal name
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
dstPath := gfile.Join(dstDirPath, "file1.txt")
content := client.PostContent("/upload/single", g.Map{
"file": "@file:" + srcPath,
})
gtest.AssertNE(content, "")
gtest.AssertNE(content, "upload file cannot be empty")
gtest.AssertNE(content, "upload failed")
gtest.Assert(content, "file1.txt")
gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath))
})
// randomly rename.
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt")
content := client.PostContent("/upload/single", g.Map{
"file": "@file:" + srcPath,
"randomlyRename": true,
})
dstPath := gfile.Join(dstDirPath, content)
gtest.AssertNE(content, "")
gtest.AssertNE(content, "upload file cannot be empty")
gtest.AssertNE(content, "upload failed")
gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath))
})
}
func Test_Params_File_CustomName(t *testing.T) {
dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
p := ports.PopRand()
s := g.Server(p)
s.BindHandler("/upload/single", func(r *ghttp.Request) {
file := r.GetUploadFile("file")
if file == nil {
r.Response.WriteExit("upload file cannot be empty")
}
file.Filename = "my.txt"
if name, err := file.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil {
r.Response.WriteExit(name)
}
r.Response.WriteExit("upload failed")
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
dstPath := gfile.Join(dstDirPath, "my.txt")
content := client.PostContent("/upload/single", g.Map{
"file": "@file:" + srcPath,
})
gtest.AssertNE(content, "")
gtest.AssertNE(content, "upload file cannot be empty")
gtest.AssertNE(content, "upload failed")
gtest.Assert(content, "my.txt")
gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath))
})
}
func Test_Params_File_Batch(t *testing.T) {
dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr())
p := ports.PopRand()
s := g.Server(p)
s.BindHandler("/upload/batch", func(r *ghttp.Request) {
files := r.GetUploadFiles("file")
if files == nil {
r.Response.WriteExit("upload file cannot be empty")
}
if names, err := files.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil {
r.Response.WriteExit(gstr.Join(names, ","))
}
r.Response.WriteExit("upload failed")
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
// normal name
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt")
dstPath1 := gfile.Join(dstDirPath, "file1.txt")
dstPath2 := gfile.Join(dstDirPath, "file2.txt")
content := client.PostContent("/upload/batch", g.Map{
"file[0]": "@file:" + srcPath1,
"file[1]": "@file:" + srcPath2,
})
gtest.AssertNE(content, "")
gtest.AssertNE(content, "upload file cannot be empty")
gtest.AssertNE(content, "upload failed")
gtest.Assert(content, "file1.txt,file2.txt")
gtest.Assert(gfile.GetContents(dstPath1), gfile.GetContents(srcPath1))
gtest.Assert(gfile.GetContents(dstPath2), gfile.GetContents(srcPath2))
})
// randomly rename.
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt")
srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt")
content := client.PostContent("/upload/batch", g.Map{
"file[0]": "@file:" + srcPath1,
"file[1]": "@file:" + srcPath2,
"randomlyRename": true,
})
gtest.AssertNE(content, "")
gtest.AssertNE(content, "upload file cannot be empty")
gtest.AssertNE(content, "upload failed")
array := gstr.SplitAndTrim(content, ",")
gtest.Assert(len(array), 2)
dstPath1 := gfile.Join(dstDirPath, array[0])
dstPath2 := gfile.Join(dstDirPath, array[1])
gtest.Assert(gfile.GetContents(dstPath1), gfile.GetContents(srcPath1))
gtest.Assert(gfile.GetContents(dstPath2), gfile.GetContents(srcPath2))
})
}

View File

@ -0,0 +1,48 @@
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp_test
import (
"fmt"
"testing"
"time"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/test/gtest"
)
func Test_Params_Page(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Group("/", func(group *ghttp.RouterGroup) {
group.GET("/list", func(r *ghttp.Request) {
page := r.GetPage(5, 2)
r.Response.Write(page.GetContent(4))
})
group.GET("/list/{page}.html", func(r *ghttp.Request) {
page := r.GetPage(5, 2)
r.Response.Write(page.GetContent(4))
})
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/list"), `<span class="GPageSpan">首页</span><span class="GPageSpan">上一页</span><span class="GPageSpan">1</span><a class="GPageLink" href="/list?page=2" title="2">2</a><a class="GPageLink" href="/list?page=3" title="3">3</a><a class="GPageLink" href="/list?page=2" title="">下一页</a><a class="GPageLink" href="/list?page=3" title="">尾页</a>`)
gtest.Assert(client.GetContent("/list?page=3"), `<a class="GPageLink" href="/list?page=1" title="">首页</a><a class="GPageLink" href="/list?page=2" title="">上一页</a><a class="GPageLink" href="/list?page=1" title="1">1</a><a class="GPageLink" href="/list?page=2" title="2">2</a><span class="GPageSpan">3</span><span class="GPageSpan">下一页</span><span class="GPageSpan">尾页</span>`)
gtest.Assert(client.GetContent("/list/1.html"), `<span class="GPageSpan">首页</span><span class="GPageSpan">上一页</span><span class="GPageSpan">1</span><a class="GPageLink" href="/list/2.html" title="2">2</a><a class="GPageLink" href="/list/3.html" title="3">3</a><a class="GPageLink" href="/list/2.html" title="">下一页</a><a class="GPageLink" href="/list/3.html" title="">尾页</a>`)
gtest.Assert(client.GetContent("/list/3.html"), `<a class="GPageLink" href="/list/1.html" title="">首页</a><a class="GPageLink" href="/list/2.html" title="">上一页</a><a class="GPageLink" href="/list/1.html" title="1">1</a><a class="GPageLink" href="/list/2.html" title="2">2</a><span class="GPageSpan">3</span><span class="GPageSpan">下一页</span><span class="GPageSpan">尾页</span>`)
})
}

View File

@ -46,14 +46,6 @@ func (c *ControllerRest) Delete() {
c.Response.Write("Controller Delete")
}
func (c *ControllerRest) Patch() {
c.Response.Write("Controller Patch")
}
func (c *ControllerRest) Options() {
c.Response.Write("Controller Options")
}
func (c *ControllerRest) Head() {
c.Response.Header().Set("head-ok", "1")
}
@ -78,8 +70,6 @@ func Test_Router_ControllerRest(t *testing.T) {
gtest.Assert(client.PutContent("/"), "1Controller Put2")
gtest.Assert(client.PostContent("/"), "1Controller Post2")
gtest.Assert(client.DeleteContent("/"), "1Controller Delete2")
gtest.Assert(client.PatchContent("/"), "1Controller Patch2")
gtest.Assert(client.OptionsContent("/"), "1Controller Options2")
resp1, err := client.Head("/")
if err == nil {
defer resp1.Close()
@ -91,8 +81,6 @@ func Test_Router_ControllerRest(t *testing.T) {
gtest.Assert(client.PutContent("/controller-rest/put"), "1Controller Put2")
gtest.Assert(client.PostContent("/controller-rest/post"), "1Controller Post2")
gtest.Assert(client.DeleteContent("/controller-rest/delete"), "1Controller Delete2")
gtest.Assert(client.PatchContent("/controller-rest/patch"), "1Controller Patch2")
gtest.Assert(client.OptionsContent("/controller-rest/options"), "1Controller Options2")
resp2, err := client.Head("/controller-rest/head")
if err == nil {
defer resp2.Close()

View File

@ -70,7 +70,7 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) {
r.Response.Write(r.Router.Uri)
})
s.SetPort(p)
s.SetDumpRouterMap(false)
//s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()

1
net/ghttp/testdata/upload/file1.txt vendored Normal file
View File

@ -0,0 +1 @@
file1.txt: This file is for uploading unit test case.

1
net/ghttp/testdata/upload/file2.txt vendored Normal file
View File

@ -0,0 +1 @@
file2.txt: This file is for uploading unit test case.

View File

@ -10,6 +10,7 @@ package gfile
import (
"bytes"
"errors"
"github.com/gogf/gf/text/gstr"
"os"
"os/exec"
"os/user"
@ -30,17 +31,32 @@ const (
var (
// Default perm for file opening.
DefaultPerm = os.FileMode(0666)
// The absolute file path for main package.
// It can be only checked and set once.
mainPkgPath = gtype.NewString()
// selfPath is the current running binary path.
// As it is most commonly used, it is so defined as an internal package variable.
selfPath = ""
// Temporary directory of system.
tempDir = "/tmp"
)
func init() {
// Initialize internal package variable: tempDir.
if !Exists(tempDir) {
tempDir = os.TempDir()
}
// Initialize internal package variable: selfPath.
selfPath, _ = exec.LookPath(os.Args[0])
if selfPath != "" {
selfPath, _ = filepath.Abs(selfPath)
}
if selfPath == "" {
selfPath, _ = filepath.Abs(os.Args[0])
}
}
// Mkdir creates directories recursively with given <path>.
@ -58,22 +74,27 @@ func Mkdir(path string) error {
func Create(path string) (*os.File, error) {
dir := Dir(path)
if !Exists(dir) {
Mkdir(dir)
if err := Mkdir(dir); err != nil {
return nil, err
}
}
return os.Create(path)
}
// Open opens file/directory readonly.
// Open opens file/directory READONLY.
func Open(path string) (*os.File, error) {
return os.Open(path)
}
// OpenFile opens file/directory with given <flag> and <perm>.
// OpenFile opens file/directory with custom <flag> and <perm>.
// The parameter <flag> is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc.
func OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) {
return os.OpenFile(path, flag, perm)
}
// OpenWithFlag opens file/directory with default perm and given <flag>.
// OpenWithFlag opens file/directory with default perm and custom <flag>.
// The default <perm> is 0666.
// The parameter <flag> is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc.
func OpenWithFlag(path string, flag int) (*os.File, error) {
f, err := os.OpenFile(path, flag, DefaultPerm)
if err != nil {
@ -82,9 +103,11 @@ func OpenWithFlag(path string, flag int) (*os.File, error) {
return f, nil
}
// OpenWithFlagPerm opens file/directory with given <flag> and <perm>.
// OpenWithFlagPerm opens file/directory with custom <flag> and <perm>.
// The parameter <flag> is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc.
// The parameter <perm> is like: 0600, 0666, 0777, etc.
func OpenWithFlagPerm(path string, flag int, perm os.FileMode) (*os.File, error) {
f, err := os.OpenFile(path, flag, os.FileMode(perm))
f, err := os.OpenFile(path, flag, perm)
if err != nil {
return nil, err
}
@ -93,7 +116,14 @@ func OpenWithFlagPerm(path string, flag int, perm os.FileMode) (*os.File, error)
// Join joins string array paths with file separator of current system.
func Join(paths ...string) string {
return strings.Join(paths, Separator)
var s string
for _, path := range paths {
if s != "" {
s += Separator
}
s += gstr.TrimRight(path, Separator)
}
return s
}
// Exists checks whether given <path> exist.
@ -198,6 +228,7 @@ func Glob(pattern string, onlyNames ...bool) ([]string, error) {
// Remove deletes all file/directory with <path> parameter.
// If parameter <path> is directory, it deletes it recursively.
func Remove(path string) error {
//intlog.Print(`Remove:`, path)
return os.RemoveAll(path)
}
@ -268,17 +299,7 @@ func RealPath(path string) string {
// SelfPath returns absolute file path of current running process(binary).
func SelfPath() string {
path, _ := exec.LookPath(os.Args[0])
if path != "" {
path, _ = filepath.Abs(path)
if path != "" {
return path
}
}
if path == "" {
path, _ = filepath.Abs(os.Args[0])
}
return path
return selfPath
}
// SelfName returns file name of current running process(binary).

View File

@ -380,9 +380,9 @@ func ParseTimeFromContent(content string, format ...string) *Time {
// FuncCost calculates the cost time of function <f> in nanoseconds.
func FuncCost(f func()) int64 {
t := Nanosecond()
t := TimestampNano()
f()
return Nanosecond() - t
return TimestampNano() - t
}
// isNumeric checks whether given <s> is a number.

View File

@ -157,6 +157,7 @@ func (t *Time) Nanosecond() int {
return t.Time.Nanosecond()
}
// String returns current time object as string.
func (t *Time) String() string {
if t == nil {
return ""
@ -167,14 +168,6 @@ func (t *Time) String() string {
return t.Format("Y-m-d H:i:s")
}
// String returns current time object as string.
//func (t Time) String() string {
// if t.IsZero() {
// return ""
// }
// return t.Format("Y-m-d H:i:s")
//}
// Clone returns a new Time object which is a clone of current time object.
func (t *Time) Clone() *Time {
return New(t.Time)

View File

@ -4,16 +4,18 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// Package gtimer implements Hierarchical Timing Wheel for interval/delayed jobs running and management.
// Package gtimer implements Hierarchical Timing Wheel for interval/delayed jobs
// running and management.
//
// This package is designed for management for millions of timing jobs.
// The differences between gtime and gcron are as follows:
// 1. gcron is implemented based on gtimer.
// This package is designed for management for millions of timing jobs. The differences
// between gtimer and gcron are as follows:
// 1. package gcron is implemented based on package gtimer.
// 2. gtimer is designed for high performance and for millions of timing jobs.
// 3. gcron supports pattern grammar like linux crontab.
// 4. gtimer's benchmark OP is measured in nanoseconds, and gcron's benchmark OP is measured in microseconds.
// 3. gcron supports configuration pattern grammar like linux crontab, which is more manually readable.
// 4. gtimer's benchmark OP is measured in nanoseconds, and gcron's benchmark OP is measured
// in microseconds.
//
// Note the common delay of the timer: https://github.com/golang/go/issues/14410
// ALSO VERY NOTE the common delay of the timer: https://github.com/golang/go/issues/14410
package gtimer
import (
@ -119,8 +121,10 @@ func DelayAddTimes(delay time.Duration, interval time.Duration, times int, job J
defaultTimer.DelayAddTimes(delay, interval, times, job)
}
// Exit is used in timing job, which exits and marks it closed from timer.
// The timing job will be removed from timer later.
// Exit is used in timing job internally, which exits and marks it closed from timer.
// The timing job will be automatically removed from timer later. It uses "panic-recover"
// mechanism internally implementing this feature, which is designed for simplification
// and convenience.
func Exit() {
panic(gPANIC_EXIT)
}

View File

@ -14,7 +14,7 @@ import (
var (
regexMu = sync.RWMutex{}
// Cache for regex object.
// TODO There's no expiring logic for this map.
// Note that there's no expiring logic for this map.
regexMap = make(map[string]*regexp.Regexp)
)
@ -22,29 +22,25 @@ var (
// It uses cache to enhance the performance for compiling regular expression pattern,
// which means, it will return the same *regexp.Regexp object with the same regular
// expression pattern.
func getRegexp(pattern string) (*regexp.Regexp, error) {
if r := getCache(pattern); r != nil {
return r, nil
}
if r, err := regexp.Compile(pattern); err == nil {
setCache(pattern, r)
return r, nil
} else {
return nil, err
}
}
// getCache returns *regexp.Regexp object from cache by given <pattern>, for internal usage.
func getCache(pattern string) (regex *regexp.Regexp) {
//
// It is concurrent-safe for multiple goroutines.
func getRegexp(pattern string) (regex *regexp.Regexp, err error) {
// Retrieve the regular expression object using reading lock.
regexMu.RLock()
regex = regexMap[pattern]
regexMu.RUnlock()
return
}
// setCache stores *regexp.Regexp object into cache, for internal usage.
func setCache(pattern string, regex *regexp.Regexp) {
if regex != nil {
return
}
// If it does not exist in the cache,
// it compiles the pattern and creates one.
regex, err = regexp.Compile(pattern)
if err != nil {
return
}
// Cache the result object using writing lock.
regexMu.Lock()
regexMap[pattern] = regex
regexMu.Unlock()
return
}

View File

@ -465,16 +465,16 @@ func SplitAndTrimSpace(str, delimiter string) []string {
return array
}
// Join concatenates the elements of a to create a single string. The separator string
// sep is placed between elements in the resulting string.
// Join concatenates the elements of <array> to create a single string. The separator string
// <sep> is placed between elements in the resulting string.
func Join(array []string, sep string) string {
return strings.Join(array, sep)
}
// JoinAny concatenates the elements of a to create a single string. The separator string
// sep is placed between elements in the resulting string.
// JoinAny concatenates the elements of <array> to create a single string. The separator string
// <sep> is placed between elements in the resulting string.
//
// The parameter <array> can be any type of slice.
// The parameter <array> can be any type of slice, which be converted to string array.
func JoinAny(array interface{}, sep string) string {
return strings.Join(gconv.Strings(array), sep)
}

View File

@ -48,7 +48,8 @@ var (
)
// Convert converts the variable <i> to the type <t>, the type <t> is specified by string.
// The optional parameter <params> is used for additional parameter passing.
// The optional parameter <params> is used for additional necessary parameter for this conversion.
// It supports common types conversion as its conversion based on type name string.
func Convert(i interface{}, t string, params ...interface{}) interface{} {
switch t {
case "int":
@ -121,6 +122,7 @@ func Convert(i interface{}, t string, params ...interface{}) interface{} {
case "Duration", "time.Duration":
return Duration(i)
default:
return i
}
}

View File

@ -40,6 +40,7 @@ func MapDeep(value interface{}, tags ...string) map[string]interface{} {
}
// doMapConvert implements the map converting.
// It automatically checks and converts json string to map if <value> is string/[]byte.
func doMapConvert(value interface{}, recursive bool, tags ...string) map[string]interface{} {
if value == nil {
return nil
@ -122,7 +123,7 @@ func doMapConvert(value interface{}, recursive bool, tags ...string) map[string]
for k, v := range r {
m[String(k)] = v
}
// Not a common type, then use reflection.
// Not a common type, it then uses reflection for conversion.
default:
rv := reflect.ValueOf(value)
kind := rv.Kind()
@ -132,6 +133,20 @@ func doMapConvert(value interface{}, recursive bool, tags ...string) map[string]
kind = rv.Kind()
}
switch kind {
// If <value> is type of array, it converts the value of even number index as its key and
// the value of odd number index as its corresponding value.
// Eg:
// []string{"k1","v1","k2","v2"} => map[string]interface{}{"k1":"v1", "k2":"v2"}
// []string{"k1","v1","k2"} => map[string]interface{}{"k1":"v1", "k2":nil}
case reflect.Slice, reflect.Array:
length := rv.Len()
for i := 0; i < length; i += 2 {
if i+1 < length {
m[String(rv.Index(i).Interface())] = rv.Index(i + 1).Interface()
} else {
m[String(rv.Index(i).Interface())] = nil
}
}
case reflect.Map:
ks := rv.MapKeys()
for _, k := range ks {

View File

@ -242,6 +242,7 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
if !structFieldValue.CanSet() {
return nil
}
// If any panic, it secondly uses reflect conversion and assignment.
defer func() {
if recover() != nil {
err = bindVarToReflectValue(structFieldValue, value)
@ -250,6 +251,7 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
if empty.IsNil(value) {
structFieldValue.Set(reflect.Zero(structFieldValue.Type()))
} else {
// It firstly simply assigns the value to the attribute.
structFieldValue.Set(reflect.ValueOf(Convert(value, structFieldValue.Type().String())))
}
return nil
@ -260,7 +262,8 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
switch structFieldValue.Kind() {
case reflect.Struct:
if err := Struct(value, structFieldValue); err != nil {
structFieldValue.Set(reflect.ValueOf(value))
// Note there's reflect conversion mechanism here.
structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type()))
}
// Note that the slice element might be type of struct,
// so it uses Struct function doing the converting internally.
@ -275,13 +278,15 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
if t.Kind() == reflect.Ptr {
e := reflect.New(t.Elem()).Elem()
if err := Struct(v.Index(i).Interface(), e); err != nil {
e.Set(reflect.ValueOf(v.Index(i).Interface()))
// Note there's reflect conversion mechanism here.
e.Set(reflect.ValueOf(v.Index(i).Interface()).Convert(t))
}
a.Index(i).Set(e.Addr())
} else {
e := reflect.New(t).Elem()
if err := Struct(v.Index(i).Interface(), e); err != nil {
e.Set(reflect.ValueOf(v.Index(i).Interface()))
// Note there's reflect conversion mechanism here.
e.Set(reflect.ValueOf(v.Index(i).Interface()).Convert(t))
}
a.Index(i).Set(e)
}
@ -293,13 +298,15 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
if t.Kind() == reflect.Ptr {
e := reflect.New(t.Elem()).Elem()
if err := Struct(value, e); err != nil {
e.Set(reflect.ValueOf(value))
// Note there's reflect conversion mechanism here.
e.Set(reflect.ValueOf(value).Convert(t))
}
a.Index(0).Set(e.Addr())
} else {
e := reflect.New(t).Elem()
if err := Struct(value, e); err != nil {
e.Set(reflect.ValueOf(value))
// Note there's reflect conversion mechanism here.
e.Set(reflect.ValueOf(value).Convert(t))
}
a.Index(0).Set(e)
}
@ -311,34 +318,40 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e
// Assign value with interface Set.
// Note that only pointer can implement interface Set.
if v, ok := item.Interface().(apiUnmarshalValue); ok {
v.UnmarshalValue(value)
err = v.UnmarshalValue(value)
structFieldValue.Set(item)
return nil
return err
}
elem := item.Elem()
if err = bindVarToReflectValue(elem, value); err == nil {
structFieldValue.Set(elem.Addr())
}
// It mainly and specially handles the interface of nil value.
case reflect.Interface:
if value == nil {
// Specially.
structFieldValue.Set(reflect.ValueOf((*interface{})(nil)))
} else {
structFieldValue.Set(reflect.ValueOf(value))
// Note there's reflect conversion mechanism here.
structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type()))
}
default:
defer func() {
if e := recover(); e != nil {
err = errors.New(
fmt.Sprintf(`cannot convert "%d" to type "%s"`,
fmt.Sprintf(`cannot convert value "%d" to type "%s"`,
value,
structFieldValue.Type().String(),
),
)
}
}()
structFieldValue.Set(reflect.ValueOf(value))
// It here uses reflect converting <value> to type of the attribute and assigns
// the result value to the attribute. It might fail and panic if the usual Go
// conversion rules do not allow conversion.
structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type()))
}
return nil
}

View File

@ -37,6 +37,23 @@ func Test_Map_Basic(t *testing.T) {
})
}
func Test_Map_Slice(t *testing.T) {
gtest.Case(t, func() {
slice1 := g.Slice{"1", "2", "3", "4"}
slice2 := g.Slice{"1", "2", "3"}
slice3 := g.Slice{}
gtest.Assert(gconv.Map(slice1), g.Map{
"1": "2",
"3": "4",
})
gtest.Assert(gconv.Map(slice2), g.Map{
"1": "2",
"3": nil,
})
gtest.Assert(gconv.Map(slice3), g.Map{})
})
}
func Test_Map_StructWithGconvTag(t *testing.T) {
gtest.Case(t, func() {
type User struct {

View File

@ -324,6 +324,36 @@ func Test_Struct_Attr_Struct_Slice_Ptr(t *testing.T) {
})
}
func Test_Struct_Attr_CustomType1(t *testing.T) {
type MyInt int
type User struct {
Id MyInt
Name string
}
gtest.Case(t, func() {
user := new(User)
err := gconv.Struct(g.Map{"id": 1, "name": "john"}, user)
gtest.Assert(err, nil)
gtest.Assert(user.Id, 1)
gtest.Assert(user.Name, "john")
})
}
func Test_Struct_Attr_CustomType2(t *testing.T) {
type MyInt int
type User struct {
Id []MyInt
Name string
}
gtest.Case(t, func() {
user := new(User)
err := gconv.Struct(g.Map{"id": g.Slice{1, 2}, "name": "john"}, user)
gtest.Assert(err, nil)
gtest.Assert(user.Id, g.Slice{1, 2})
gtest.Assert(user.Name, "john")
})
}
func Test_Struct_PrivateAttribute(t *testing.T) {
type User struct {
Id int

View File

@ -9,292 +9,217 @@ package gpage
import (
"fmt"
"math"
url2 "net/url"
"strings"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/text/gregex"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"math"
)
// 分页对象
// Page is the pagination implementer.
// All the attributes are public, you can change them when necessary.
type Page struct {
Url *url2.URL // 当前页面的URL对象
Router *ghttp.Router // 当前页面的路由对象(与gf框架耦合在静态分页下有效)
UrlTemplate string // URL生成规则内部可使用{.page}变量指定页码
TotalSize int // 总共数据条数
TotalPage int // 总页数
CurrentPage int // 当前页码
PageName string // 分页参数名称(GET参数)
NextPageTag string // 下一页标签
PrevPageTag string // 上一页标签
FirstPageTag string // 首页标签
LastPageTag string // 尾页标签
PrevBar string // 上一分页条
NextBar string // 下一分页条
PageBarNum int // 控制分页条的数量
AjaxActionName string // AJAX方法名当该属性有值时表示使用AJAX分页
TotalSize int // Total size.
TotalPage int // Total page, which is automatically calculated.
CurrentPage int // Current page number >= 1.
UrlTemplate string // Custom url template for page url producing.
LinkStyle string // CSS style name for HTML link tag <a>.
SpanStyle string // CSS style name for HTML span tag <span>, which is used for first, current and last page tag.
SelectStyle string // CSS style name for HTML select tag <select>.
NextPageTag string // Tag name for next p.
PrevPageTag string // Tag name for prev p.
FirstPageTag string // Tag name for first p.
LastPageTag string // Tag name for last p.
PrevBarTag string // Tag string for prev bar.
NextBarTag string // Tag string for next bar.
PageBarNum int // Page bar number for displaying.
AjaxActionName string // Ajax function name. Ajax is enabled if this attribute is not empty.
}
// 创建一个分页对象,输入参数分别为:
// 总数量、每页数量、当前页码、当前的URL(URI+QUERY)、(可选)路由规则(例如: /user/list/:page、/order/list/*page、/order/list/{page}.html)
func New(TotalSize, perPage int, CurrentPage interface{}, url string, router ...*ghttp.Router) *Page {
u, _ := url2.Parse(url)
page := &Page{
PageName: "page",
const (
PAGE_NAME = "page" // PAGE_NAME defines the default page name.
PAGE_PLACE_HOLDER = "{.page}" // PAGE_PLACE_HOLDER defines the place holder for the url template.
)
// New creates and returns a pagination manager.
// Note that the parameter <urlTemplate> specifies the URL producing template, like:
// /user/list/{.page}, /user/list/{.page}.html, /user/list?page={.page}&type=1, etc.
// The build-in variable in <urlTemplate> "{.page}" specifies the page number, which will be replaced by certain
// page number when producing.
func New(totalSize, pageSize, currentPage int, urlTemplate string) *Page {
p := &Page{
LinkStyle: "GPageLink",
SpanStyle: "GPageSpan",
SelectStyle: "GPageSelect",
PrevPageTag: "<",
NextPageTag: ">",
FirstPageTag: "|<",
LastPageTag: ">|",
PrevBar: "<<",
NextBar: ">>",
TotalSize: TotalSize,
TotalPage: int(math.Ceil(float64(TotalSize) / float64(perPage))),
CurrentPage: 1,
PrevBarTag: "<<",
NextBarTag: ">>",
TotalSize: totalSize,
TotalPage: int(math.Ceil(float64(totalSize) / float64(pageSize))),
CurrentPage: currentPage,
PageBarNum: 10,
Url: u,
UrlTemplate: urlTemplate,
}
curPage := gconv.Int(CurrentPage)
if curPage > 0 {
page.CurrentPage = curPage
if currentPage == 0 {
p.CurrentPage = 1
}
if len(router) > 0 {
page.Router = router[0]
}
return page
return p
}
// 启用AJAX分页
func (page *Page) EnableAjax(actionName string) {
page.AjaxActionName = actionName
// NextPage returns the HTML content for the next page.
func (p *Page) NextPage() string {
if p.CurrentPage < p.TotalPage {
return p.GetLink(p.CurrentPage+1, p.NextPageTag, "")
}
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.NextPageTag)
}
// 设置URL生成规则模板模板中可使用{.page}变量指定页码位置
func (page *Page) SetUrlTemplate(template string) {
page.UrlTemplate = template
// PrevPage returns the HTML content for the previous page.
func (p *Page) PrevPage() string {
if p.CurrentPage > 1 {
return p.GetLink(p.CurrentPage-1, p.PrevPageTag, "")
}
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.PrevPageTag)
}
// 获取显示"下一页"的内容.
func (page *Page) NextPage(styles ...string) string {
var curStyle, style string
if len(styles) > 0 {
curStyle = styles[0]
// FirstPage returns the HTML content for the first page.
func (p *Page) FirstPage() string {
if p.CurrentPage == 1 {
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.FirstPageTag)
}
if len(styles) > 1 {
style = styles[0]
}
if page.CurrentPage < page.TotalPage {
return page.GetLink(page.GetUrl(page.CurrentPage+1), page.NextPageTag, "下一页", style)
}
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.NextPageTag)
return p.GetLink(1, p.FirstPageTag, "")
}
// 获取显示“上一页”的内容
func (page *Page) PrevPage(styles ...string) string {
var curStyle, style string
if len(styles) > 0 {
curStyle = styles[0]
// LastPage returns the HTML content for the last page.
func (p *Page) LastPage() string {
if p.CurrentPage == p.TotalPage {
return fmt.Sprintf(`<span class="%s">%s</span>`, p.SpanStyle, p.LastPageTag)
}
if len(styles) > 1 {
style = styles[0]
}
if page.CurrentPage > 1 {
return page.GetLink(page.GetUrl(page.CurrentPage-1), page.PrevPageTag, "上一页", style)
}
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.PrevPageTag)
return p.GetLink(p.TotalPage, p.LastPageTag, "")
}
// 获取显示“首页”的代码
func (page *Page) FirstPage(styles ...string) string {
var curStyle, style string
if len(styles) > 0 {
curStyle = styles[0]
// PageBar returns the HTML page bar content with link and span tags.
func (p *Page) PageBar() string {
plus := int(math.Ceil(float64(p.PageBarNum / 2)))
if p.PageBarNum-plus+p.CurrentPage > p.TotalPage {
plus = p.PageBarNum - p.TotalPage + p.CurrentPage
}
if len(styles) > 1 {
style = styles[0]
}
if page.CurrentPage == 1 {
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.FirstPageTag)
}
return page.GetLink(page.GetUrl(1), page.FirstPageTag, "第一页", style)
}
// 获取显示“尾页”的内容
func (page *Page) LastPage(styles ...string) string {
var curStyle, style string
if len(styles) > 0 {
curStyle = styles[0]
}
if len(styles) > 1 {
style = styles[0]
}
if page.CurrentPage == page.TotalPage {
return fmt.Sprintf(`<span class="%s">%s</span>`, curStyle, page.LastPageTag)
}
return page.GetLink(page.GetUrl(page.TotalPage), page.LastPageTag, "最后页", style)
}
// 获得分页条列表内容
func (page *Page) PageBar(styles ...string) string {
var curStyle, style string
if len(styles) > 0 {
curStyle = styles[0]
}
if len(styles) > 1 {
style = styles[0]
}
plus := int(math.Ceil(float64(page.PageBarNum / 2)))
if page.PageBarNum-plus+page.CurrentPage > page.TotalPage {
plus = page.PageBarNum - page.TotalPage + page.CurrentPage
}
begin := page.CurrentPage - plus + 1
begin := p.CurrentPage - plus + 1
if begin < 1 {
begin = 1
}
ret := ""
for i := begin; i < begin+page.PageBarNum; i++ {
if i <= page.TotalPage {
if i != page.CurrentPage {
ret += page.GetLink(page.GetUrl(i), gconv.String(i), style, "")
barContent := ""
for i := begin; i < begin+p.PageBarNum; i++ {
if i <= p.TotalPage {
if i != p.CurrentPage {
barText := gconv.String(i)
barContent += p.GetLink(i, barText, barText)
} else {
ret += fmt.Sprintf(`<span class="%s">%d</span>`, curStyle, i)
barContent += fmt.Sprintf(`<span class="%s">%d</span>`, p.SpanStyle, i)
}
} else {
break
}
}
return ret
return barContent
}
// 获取基于select标签的显示跳转按钮的代码
func (page *Page) SelectBar() string {
ret := `<select name="gpage_select" onchange="window.location.href=this.value">`
for i := 1; i <= page.TotalPage; i++ {
if i == page.CurrentPage {
ret += fmt.Sprintf(`<option value="%s" selected>%d</option>`, page.GetUrl(i), i)
// SelectBar returns the select HTML content for pagination.
func (p *Page) SelectBar() string {
barContent := fmt.Sprintf(`<select name="%s" onchange="window.location.href=this.value">`, p.SelectStyle)
for i := 1; i <= p.TotalPage; i++ {
if i == p.CurrentPage {
barContent += fmt.Sprintf(`<option value="%s" selected>%d</option>`, p.GetUrl(i), i)
} else {
ret += fmt.Sprintf(`<option value="%s">%d</option>`, page.GetUrl(i), i)
barContent += fmt.Sprintf(`<option value="%s">%d</option>`, p.GetUrl(i), i)
}
}
ret += "</select>"
return ret
barContent += "</select>"
return barContent
}
// 预定义的分页显示风格内容
func (page *Page) GetContent(mode int) string {
// GetContent returns the page content for predefined mode.
// These predefined contents are mainly for chinese localization purpose. You can defines your own
// page function retrieving the page content according to the implementation of this function.
func (p *Page) GetContent(mode int) string {
switch mode {
case 1:
page.NextPageTag = "下一页"
page.PrevPageTag = "上一页"
p.NextPageTag = "下一页"
p.PrevPageTag = "上一页"
return fmt.Sprintf(
`%s <span class="current">%d</span> %s`,
page.PrevPage(),
page.CurrentPage,
page.NextPage(),
p.PrevPage(),
p.CurrentPage,
p.NextPage(),
)
case 2:
page.NextPageTag = "下一页>>"
page.PrevPageTag = "<<上一页"
page.FirstPageTag = "首页"
page.LastPageTag = "尾页"
p.NextPageTag = "下一页>>"
p.PrevPageTag = "<<上一页"
p.FirstPageTag = "首页"
p.LastPageTag = "尾页"
return fmt.Sprintf(
`%s%s<span class="current">[第%d页]</span>%s%s第%s页`,
page.FirstPage(),
page.PrevPage(),
page.CurrentPage,
page.NextPage(),
page.LastPage(),
page.SelectBar(),
p.FirstPage(),
p.PrevPage(),
p.CurrentPage,
p.NextPage(),
p.LastPage(),
p.SelectBar(),
)
case 3:
page.NextPageTag = "下一页"
page.PrevPageTag = "上一页"
page.FirstPageTag = "首页"
page.LastPageTag = "尾页"
pageStr := page.FirstPage()
pageStr += page.PrevPage()
pageStr += page.PageBar("current")
pageStr += page.NextPage()
pageStr += page.LastPage()
p.NextPageTag = "下一页"
p.PrevPageTag = "上一页"
p.FirstPageTag = "首页"
p.LastPageTag = "尾页"
pageStr := p.FirstPage()
pageStr += p.PrevPage()
pageStr += p.PageBar()
pageStr += p.NextPage()
pageStr += p.LastPage()
pageStr += fmt.Sprintf(
`<span>当前页%d/%d</span> <span>共%d条</span>`,
page.CurrentPage,
page.TotalPage,
page.TotalSize,
p.CurrentPage,
p.TotalPage,
p.TotalSize,
)
return pageStr
case 4:
page.NextPageTag = "下一页"
page.PrevPageTag = "上一页"
page.FirstPageTag = "首页"
page.LastPageTag = "尾页"
pageStr := page.FirstPage()
pageStr += page.PrevPage()
pageStr += page.PageBar("current")
pageStr += page.NextPage()
pageStr += page.LastPage()
p.NextPageTag = "下一页"
p.PrevPageTag = "上一页"
p.FirstPageTag = "首页"
p.LastPageTag = "尾页"
pageStr := p.FirstPage()
pageStr += p.PrevPage()
pageStr += p.PageBar()
pageStr += p.NextPage()
pageStr += p.LastPage()
return pageStr
}
return ""
}
// 为指定的页面返回地址值
func (page *Page) GetUrl(pageNo int) string {
// 复制一个URL对象
url := *page.Url
if len(page.UrlTemplate) == 0 && page.Router != nil {
page.UrlTemplate = page.makeUrlTemplate(url.Path, page.Router)
}
if len(page.UrlTemplate) > 0 {
// 指定URL生成模板
url.Path = gstr.Replace(page.UrlTemplate, "{.page}", gconv.String(pageNo))
return url.String()
}
values := page.Url.Query()
values.Set(page.PageName, gconv.String(pageNo))
url.RawQuery = values.Encode()
return url.String()
// GetUrl parses the UrlTemplate with given page number and returns the URL string.
// Note that the UrlTemplate attribute can be either an URL or a URI string with "{.page}"
// place holder specifying the page number position.
func (p *Page) GetUrl(page int) string {
return gstr.Replace(p.UrlTemplate, PAGE_PLACE_HOLDER, gconv.String(page))
}
// 根据当前URL以及注册路由信息计算出对应的URL模板
func (page *Page) makeUrlTemplate(url string, router *ghttp.Router) (tpl string) {
if page.Router != nil && len(router.RegNames) > 0 {
if match, err := gregex.MatchString(router.RegRule, url); err == nil && len(match) > 0 {
if len(match) > len(router.RegNames) {
tpl = router.Uri
hasPageName := false
for i, name := range router.RegNames {
rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name)
if !hasPageName && strings.Compare(name, page.PageName) == 0 {
hasPageName = true
tpl, _ = gregex.ReplaceString(rule, `{.page}`, tpl)
} else {
tpl, _ = gregex.ReplaceString(rule, match[i+1], tpl)
}
}
if !hasPageName {
tpl = ""
}
}
}
}
return
}
// 获取链接地址
func (page *Page) GetLink(url, text, title, style string) string {
if len(style) > 0 {
style = fmt.Sprintf(`class="%s" `, style)
}
if len(page.AjaxActionName) > 0 {
return fmt.Sprintf(`<a %shref='#' onclick="%s('%s')">%s</a>`, style, page.AjaxActionName, url, text)
// GetLink returns the HTML link tag <a> content for given page number.
func (p *Page) GetLink(page int, text, title string) string {
if len(p.AjaxActionName) > 0 {
return fmt.Sprintf(
`<a class="%s" href="javascript:%s('%s')" title="%s">%s</a>`,
p.LinkStyle, p.AjaxActionName, p.GetUrl(page), title, text,
)
} else {
return fmt.Sprintf(`<a %shref="%s" title="%s">%s</a>`, style, url, title, text)
return fmt.Sprintf(
`<a class="%s" href="%s" title="%s">%s</a>`,
p.LinkStyle, p.GetUrl(page), title, text,
)
}
}

View File

@ -0,0 +1,116 @@
// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// go test *.go -bench=".*"
package gpage_test
import (
"github.com/gogf/gf/util/gpage"
"testing"
"github.com/gogf/gf/test/gtest"
)
func Test_New(t *testing.T) {
gtest.Case(t, func() {
page := gpage.New(9, 2, 1, `/user/list?page={.page}`)
gtest.Assert(page.TotalSize, 9)
gtest.Assert(page.TotalPage, 5)
gtest.Assert(page.CurrentPage, 1)
})
gtest.Case(t, func() {
page := gpage.New(9, 2, 0, `/user/list?page={.page}`)
gtest.Assert(page.TotalSize, 9)
gtest.Assert(page.TotalPage, 5)
gtest.Assert(page.CurrentPage, 1)
})
}
func Test_Basic(t *testing.T) {
gtest.Case(t, func() {
page := gpage.New(9, 2, 1, `/user/list?page={.page}`)
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="/user/list?page=2" title="">></a>`)
gtest.Assert(page.PrevPage(), `<span class="GPageSpan"><</span>`)
gtest.Assert(page.FirstPage(), `<span class="GPageSpan">|<</span>`)
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="/user/list?page=5" title="">>|</a>`)
gtest.Assert(page.PageBar(), `<span class="GPageSpan">1</span><a class="GPageLink" href="/user/list?page=2" title="2">2</a><a class="GPageLink" href="/user/list?page=3" title="3">3</a><a class="GPageLink" href="/user/list?page=4" title="4">4</a><a class="GPageLink" href="/user/list?page=5" title="5">5</a>`)
})
gtest.Case(t, func() {
page := gpage.New(9, 2, 3, `/user/list?page={.page}`)
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="/user/list?page=4" title="">></a>`)
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="/user/list?page=2" title=""><</a>`)
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="/user/list?page=1" title="">|<</a>`)
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="/user/list?page=5" title="">>|</a>`)
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="/user/list?page=1" title="1">1</a><a class="GPageLink" href="/user/list?page=2" title="2">2</a><span class="GPageSpan">3</span><a class="GPageLink" href="/user/list?page=4" title="4">4</a><a class="GPageLink" href="/user/list?page=5" title="5">5</a>`)
})
gtest.Case(t, func() {
page := gpage.New(9, 2, 5, `/user/list?page={.page}`)
gtest.Assert(page.NextPage(), `<span class="GPageSpan">></span>`)
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="/user/list?page=4" title=""><</a>`)
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="/user/list?page=1" title="">|<</a>`)
gtest.Assert(page.LastPage(), `<span class="GPageSpan">>|</span>`)
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="/user/list?page=1" title="1">1</a><a class="GPageLink" href="/user/list?page=2" title="2">2</a><a class="GPageLink" href="/user/list?page=3" title="3">3</a><a class="GPageLink" href="/user/list?page=4" title="4">4</a><span class="GPageSpan">5</span>`)
})
}
func Test_CustomTag(t *testing.T) {
gtest.Case(t, func() {
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
page.PrevPageTag = "《"
page.NextPageTag = "》"
page.FirstPageTag = "|《"
page.LastPageTag = "》|"
page.PrevBarTag = "《《"
page.NextBarTag = "》》"
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="/user/list/3" title="">》</a>`)
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="/user/list/1" title="">《</a>`)
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="/user/list/1" title="">|《</a>`)
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="/user/list/5" title="">》|</a>`)
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="/user/list/1" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="/user/list/3" title="3">3</a><a class="GPageLink" href="/user/list/4" title="4">4</a><a class="GPageLink" href="/user/list/5" title="5">5</a>`)
})
}
func Test_CustomStyle(t *testing.T) {
gtest.Case(t, func() {
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
page.LinkStyle = "MyPageLink"
page.SpanStyle = "MyPageSpan"
page.SelectStyle = "MyPageSelect"
gtest.Assert(page.NextPage(), `<a class="MyPageLink" href="/user/list/3" title="">></a>`)
gtest.Assert(page.PrevPage(), `<a class="MyPageLink" href="/user/list/1" title=""><</a>`)
gtest.Assert(page.FirstPage(), `<a class="MyPageLink" href="/user/list/1" title="">|<</a>`)
gtest.Assert(page.LastPage(), `<a class="MyPageLink" href="/user/list/5" title="">>|</a>`)
gtest.Assert(page.PageBar(), `<a class="MyPageLink" href="/user/list/1" title="1">1</a><span class="MyPageSpan">2</span><a class="MyPageLink" href="/user/list/3" title="3">3</a><a class="MyPageLink" href="/user/list/4" title="4">4</a><a class="MyPageLink" href="/user/list/5" title="5">5</a>`)
gtest.Assert(page.SelectBar(), `<select name="MyPageSelect" onchange="window.location.href=this.value"><option value="/user/list/1">1</option><option value="/user/list/2" selected>2</option><option value="/user/list/3">3</option><option value="/user/list/4">4</option><option value="/user/list/5">5</option></select>`)
})
}
func Test_Ajax(t *testing.T) {
gtest.Case(t, func() {
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
page.AjaxActionName = "LoadPage"
gtest.Assert(page.NextPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">></a>`)
gtest.Assert(page.PrevPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title=""><</a>`)
gtest.Assert(page.FirstPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">|<</a>`)
gtest.Assert(page.LastPage(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">>|</a>`)
gtest.Assert(page.PageBar(), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="3">3</a><a class="GPageLink" href="javascript:LoadPage('/user/list/4')" title="4">4</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="5">5</a>`)
})
}
func Test_PredefinedContent(t *testing.T) {
gtest.Case(t, func() {
page := gpage.New(5, 1, 2, `/user/list/{.page}`)
page.AjaxActionName = "LoadPage"
gtest.Assert(page.GetContent(1), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">上一页</a> <span class="current">2</span> <a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页</a>`)
gtest.Assert(page.GetContent(2), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">首页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title=""><<上一页</a><span class="current">[第2页]</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页>></a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">尾页</a>第<select name="GPageSelect" onchange="window.location.href=this.value"><option value="/user/list/1">1</option><option value="/user/list/2" selected>2</option><option value="/user/list/3">3</option><option value="/user/list/4">4</option><option value="/user/list/5">5</option></select>页`)
gtest.Assert(page.GetContent(3), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">首页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">上一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="3">3</a><a class="GPageLink" href="javascript:LoadPage('/user/list/4')" title="4">4</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="5">5</a><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">尾页</a><span>当前页2/5</span> <span>共5条</span>`)
gtest.Assert(page.GetContent(4), `<a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">首页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="">上一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/1')" title="1">1</a><span class="GPageSpan">2</span><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="3">3</a><a class="GPageLink" href="javascript:LoadPage('/user/list/4')" title="4">4</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="5">5</a><a class="GPageLink" href="javascript:LoadPage('/user/list/3')" title="">下一页</a><a class="GPageLink" href="javascript:LoadPage('/user/list/5')" title="">尾页</a>`)
gtest.Assert(page.GetContent(5), ``)
})
}