From a4f191c1c60e5542a95fd5d954b2e38c6ad312b4 Mon Sep 17 00:00:00 2001 From: John Date: Tue, 19 Nov 2019 21:50:17 +0800 Subject: [PATCH] fix issue in gdb.Model for repeated condition statements; remove concurrent safety feature of gview; add default template file feature for gview --- database/gdb/gdb_model.go | 59 +++++++++++---------- database/gdb/gdb_unit_z_mysql_model_test.go | 51 ++++++++++++++++++ os/gview/gview.go | 18 ++++--- os/gview/gview_config.go | 33 +++++------- os/gview/gview_doparse.go | 10 ++-- 5 files changed, 113 insertions(+), 58 deletions(-) diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index 9e76e703f..238a46f4d 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -31,7 +31,6 @@ type Model struct { tables string // 数据库操作表 fields string // 操作字段 fieldsEx string // 操作字段(排除) - where string // 操作条件 whereArgs []interface{} // 操作条件参数 whereHolder []*whereHolder // 操作条件预处理 groupBy string // 分组语句 @@ -550,12 +549,13 @@ func (md *Model) Update() (result sql.Result, err error) { if md.data == nil { return nil, errors.New("updating table with empty data") } + condition, conditionArgs := md.formatCondition() return md.db.doUpdate( md.getLink(), md.tables, md.filterDataForInsertOrUpdate(md.data), - md.getConditionSql(), - md.whereArgs..., + condition, + conditionArgs..., ) } @@ -566,7 +566,8 @@ func (md *Model) Delete() (result sql.Result, err error) { md.checkAndRemoveCache() } }() - return md.db.doDelete(md.getLink(), md.tables, md.getConditionSql(), md.whereArgs...) + condition, conditionArgs := md.formatCondition() + return md.db.doDelete(md.getLink(), md.tables, condition, conditionArgs...) } // 链式操作,select @@ -576,7 +577,8 @@ func (md *Model) Select() (Result, error) { // 链式操作,查询所有记录 func (md *Model) All() (Result, error) { - return md.getAll(fmt.Sprintf("SELECT %s FROM %s%s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...) + condition, conditionArgs := md.formatCondition() + return md.getAll(fmt.Sprintf("SELECT %s FROM %s%s", md.fields, md.tables, condition), conditionArgs...) } // 链式操作,查询单条记录 @@ -651,11 +653,12 @@ func (md *Model) Count() (int, error) { } else { md.fields = fmt.Sprintf(`COUNT(%s)`, md.fields) } - s := fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql()) + condition, conditionArgs := md.formatCondition() + s := fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, condition) if len(md.groupBy) > 0 { s = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", s) } - list, err := md.getAll(s, md.whereArgs...) + list, err := md.getAll(s, conditionArgs...) if err != nil { return 0, err } @@ -716,16 +719,17 @@ func (md *Model) checkAndRemoveCache() { } // 格式化当前输入参数,返回SQL条件语句(不带参数) -func (md *Model) getConditionSql() string { +func (md *Model) formatCondition() (condition string, conditionArgs []interface{}) { + var where string if len(md.whereHolder) > 0 { for _, v := range md.whereHolder { switch v.operator { case gWHERE_HOLDER_WHERE: - if md.where == "" { + if where == "" { newWhere, newArgs := formatWhere(md.db, v.where, v.args, md.option&OPTION_OMITEMPTY > 0) if len(newWhere) > 0 { - md.where = newWhere - md.whereArgs = newArgs + where = newWhere + conditionArgs = newArgs } continue } @@ -734,48 +738,47 @@ func (md *Model) getConditionSql() string { case gWHERE_HOLDER_AND: newWhere, newArgs := formatWhere(md.db, v.where, v.args, md.option&OPTION_OMITEMPTY > 0) if len(newWhere) > 0 { - if md.where[0] == '(' { - md.where = fmt.Sprintf(`%s AND (%s)`, md.where, newWhere) + if where[0] == '(' { + where = fmt.Sprintf(`%s AND (%s)`, where, newWhere) } else { - md.where = fmt.Sprintf(`(%s) AND (%s)`, md.where, newWhere) + where = fmt.Sprintf(`(%s) AND (%s)`, where, newWhere) } - md.whereArgs = append(md.whereArgs, newArgs...) + conditionArgs = append(conditionArgs, newArgs...) } case gWHERE_HOLDER_OR: newWhere, newArgs := formatWhere(md.db, v.where, v.args, md.option&OPTION_OMITEMPTY > 0) if len(newWhere) > 0 { - if md.where[0] == '(' { - md.where = fmt.Sprintf(`%s OR (%s)`, md.where, newWhere) + if where[0] == '(' { + where = fmt.Sprintf(`%s OR (%s)`, where, newWhere) } else { - md.where = fmt.Sprintf(`(%s) OR (%s)`, md.where, newWhere) + where = fmt.Sprintf(`(%s) OR (%s)`, where, newWhere) } - md.whereArgs = append(md.whereArgs, newArgs...) + conditionArgs = append(conditionArgs, newArgs...) } } } } - s := "" - if md.where != "" { - s += " WHERE " + md.where + if where != "" { + condition += " WHERE " + where } if md.groupBy != "" { - s += " GROUP BY " + md.groupBy + condition += " GROUP BY " + md.groupBy } if md.orderBy != "" { - s += " ORDER BY " + md.orderBy + condition += " ORDER BY " + md.orderBy } if md.limit != 0 { if md.start >= 0 { - s += fmt.Sprintf(" LIMIT %d,%d", md.start, md.limit) + condition += fmt.Sprintf(" LIMIT %d,%d", md.start, md.limit) } else { - s += fmt.Sprintf(" LIMIT %d", md.limit) + condition += fmt.Sprintf(" LIMIT %d", md.limit) } } if md.offset >= 0 { - s += fmt.Sprintf(" OFFSET %d", md.offset) + condition += fmt.Sprintf(" OFFSET %d", md.offset) } - return s + return } // 组块结果集。 diff --git a/database/gdb/gdb_unit_z_mysql_model_test.go b/database/gdb/gdb_unit_z_mysql_model_test.go index d2ac2bed5..850b1413d 100644 --- a/database/gdb/gdb_unit_z_mysql_model_test.go +++ b/database/gdb/gdb_unit_z_mysql_model_test.go @@ -298,6 +298,57 @@ func Test_Model_Safe(t *testing.T) { gtest.Assert(err, nil) gtest.Assert(count, 2) }) + gtest.Case(t, func() { + md1 := db.Table(table).Safe() + md2 := md1.Where("id in (?)", g.Slice{1, 3}) + count, err := md2.Count() + gtest.Assert(err, nil) + gtest.Assert(count, 2) + + all, err := md2.All() + gtest.Assert(err, nil) + gtest.Assert(len(all), 2) + + all, err = md2.ForPage(1, 10).All() + gtest.Assert(err, nil) + gtest.Assert(len(all), 2) + }) + + gtest.Case(t, func() { + md1 := db.Table(table).Where("id>", 0).Safe() + md2 := md1.Where("id in (?)", g.Slice{1, 3}) + md3 := md1.Where("id in (?)", g.Slice{4, 5, 6}) + // 1,3 + count, err := md2.Count() + gtest.Assert(err, nil) + gtest.Assert(count, 2) + + all, err := md2.OrderBy("id asc").All() + gtest.Assert(err, nil) + gtest.Assert(len(all), 2) + gtest.Assert(all[0]["id"].Int(), 1) + gtest.Assert(all[1]["id"].Int(), 3) + + all, err = md2.ForPage(1, 10).All() + gtest.Assert(err, nil) + gtest.Assert(len(all), 2) + + // 4,5,6 + count, err = md3.Count() + gtest.Assert(err, nil) + gtest.Assert(count, 3) + + all, err = md3.OrderBy("id asc").All() + gtest.Assert(err, nil) + gtest.Assert(len(all), 3) + gtest.Assert(all[0]["id"].Int(), 4) + gtest.Assert(all[1]["id"].Int(), 5) + gtest.Assert(all[2]["id"].Int(), 6) + + all, err = md3.ForPage(1, 10).All() + gtest.Assert(err, nil) + gtest.Assert(len(all), 3) + }) } func Test_Model_All(t *testing.T) { diff --git a/os/gview/gview.go b/os/gview/gview.go index d49dd4c53..376cd8e16 100644 --- a/os/gview/gview.go +++ b/os/gview/gview.go @@ -14,10 +14,8 @@ import ( "errors" "fmt" "github.com/gogf/gf/container/gmap" - "github.com/gogf/gf/internal/intlog" - "sync" - "github.com/gogf/gf/i18n/gi18n" + "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gres" @@ -31,11 +29,11 @@ import ( // View object for template engine. type View struct { - mu sync.RWMutex paths *garray.StrArray // Searching path array, NOT concurrent safe for performance purpose. data map[string]interface{} // Global template variables. funcMap map[string]interface{} // Global template function map. fileCacheMap *gmap.StrAnyMap // File cache map. + defaultFile string // Default template file for parsing. i18nManager *gi18n.Manager // I18n manager for this view. delimiters []string // Customized template delimiters. } @@ -46,8 +44,15 @@ type Params = map[string]interface{} // FuncMap is type for custom template functions. type FuncMap = map[string]interface{} -// Default view object. -var defaultViewObj *View +const ( + // Default template file for parsing. + defaultParsingFile = "index.html" +) + +var ( + // Default view object. + defaultViewObj *View +) // checkAndInitDefaultView checks and initializes the default view object. // The default view object will be initialized just once. @@ -72,6 +77,7 @@ func New(path ...string) *View { data: make(map[string]interface{}), funcMap: make(map[string]interface{}), fileCacheMap: gmap.NewStrAnyMap(true), + defaultFile: defaultParsingFile, i18nManager: gi18n.Instance(), delimiters: make([]string, 2), } diff --git a/os/gview/gview_config.go b/os/gview/gview_config.go index 4b5d3d1cd..4607ddbe7 100644 --- a/os/gview/gview_config.go +++ b/os/gview/gview_config.go @@ -8,55 +8,50 @@ package gview import "github.com/gogf/gf/i18n/gi18n" -// Assign binds multiple template variables to current view object. -// Each goroutine will take effect after the call, so it is concurrent-safe. +// Assign binds multiple global template variables to current view object. +// Note that it's not concurrent-safe, which means it would panic +// if it's called in multiple goroutines in runtime. func (view *View) Assigns(data Params) { - view.mu.Lock() for k, v := range data { view.data[k] = v } - view.mu.Unlock() } -// Assign binds a template variable to current view object. -// Each goroutine will take effect after the call, so it is concurrent-safe. +// Assign binds a global template variable to current view object. +// Note that it's not concurrent-safe, which means it would panic +// if it's called in multiple goroutines in runtime. func (view *View) Assign(key string, value interface{}) { - view.mu.Lock() view.data[key] = value - view.mu.Unlock() +} + +// SetDefaultFile sets default template file for parsing. +func (view *View) SetDefaultFile(file string) { + view.defaultFile = file } // SetDelimiters sets customized delimiters for template parsing. func (view *View) SetDelimiters(left, right string) { - view.mu.Lock() view.delimiters[0] = left view.delimiters[1] = right - view.mu.Unlock() } -// BindFunc registers customized template function named +// BindFunc registers customized global template function named // with given function to current view object. // The is the function name which can be called in template content. func (view *View) BindFunc(name string, function interface{}) { - view.mu.Lock() view.funcMap[name] = function - view.mu.Unlock() } -// BindFuncMap registers customized template functions by map to current view object. +// BindFuncMap registers customized global template functions by map to current view object. // The key of map is the template function name // and the value of map is the address of customized function. func (view *View) BindFuncMap(funcMap FuncMap) { - view.mu.Lock() for k, v := range funcMap { view.funcMap[k] = v } - view.mu.Unlock() } -// SetI18n binds i18n manager to view engine. +// SetI18n binds i18n manager to current view engine. func (view *View) SetI18n(manager *gi18n.Manager) { - view.mu.Lock() view.i18nManager = manager - view.mu.Unlock() } diff --git a/os/gview/gview_doparse.go b/os/gview/gview_doparse.go index cead95ec5..0a5e548ad 100644 --- a/os/gview/gview_doparse.go +++ b/os/gview/gview_doparse.go @@ -52,9 +52,6 @@ type fileCacheItem struct { // with given template parameters and function map // and returns the parsed string content. func (view *View) Parse(file string, params ...Params) (result string, err error) { - view.mu.RLock() - defer view.mu.RUnlock() - var tpl *template.Template // It caches the file, folder and its content to enhance performance. r := view.fileCacheMap.GetOrSetFuncLock(file, func() interface{} { @@ -141,12 +138,15 @@ func (view *View) Parse(file string, params ...Params) (result string, err error return result, nil } +// ParseDefault parses the default template file with params. +func (view *View) ParseDefault(params ...Params) (result string, err error) { + return view.Parse(view.defaultFile, params...) +} + // ParseContent parses given template content // with given template parameters and function map // and returns the parsed content in []byte. func (view *View) ParseContent(content string, params ...Params) (string, error) { - view.mu.RLock() - defer view.mu.RUnlock() err := (error)(nil) tpl := templates.GetOrSetFuncLock(gCONTENT_TEMPLATE_NAME, func() interface{} { return template.New(gCONTENT_TEMPLATE_NAME).Delims(view.delimiters[0], view.delimiters[1]).Funcs(view.funcMap)