improve session expire algorithm for gsession

This commit is contained in:
john 2019-09-15 12:17:44 +08:00
parent f128cb9f61
commit 1665d92136
8 changed files with 22 additions and 11 deletions

View File

@ -4,11 +4,12 @@ import (
"github.com/gogf/gf/frame/g" "github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp" "github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/os/gtime" "github.com/gogf/gf/os/gtime"
"time"
) )
func main() { func main() {
s := g.Server() s := g.Server()
s.SetSessionMaxAge(60) s.SetSessionMaxAge(2 * time.Second)
s.BindHandler("/set", func(r *ghttp.Request) { s.BindHandler("/set", func(r *ghttp.Request) {
r.Session.Set("time", gtime.Second()) r.Session.Set("time", gtime.Second())
r.Response.Write("ok") r.Response.Write("ok")

View File

@ -214,7 +214,7 @@ func GetServer(name ...interface{}) *Server {
} }
config := defaultServerConfig config := defaultServerConfig
if config.SessionStorage == nil { if config.SessionStorage == nil {
config.SessionStorage = gsession.NewStorageFile(config.SessionMaxAge) config.SessionStorage = gsession.NewStorageFile()
} }
s := &Server{ s := &Server{
name: serverName, name: serverName,

View File

@ -21,6 +21,7 @@ func (s *Server) SetSessionMaxAge(ttl time.Duration) {
return return
} }
s.config.SessionMaxAge = ttl s.config.SessionMaxAge = ttl
s.sessionManager.SetTTL(ttl)
} }
// 设置http server参数 - SessionIdName // 设置http server参数 - SessionIdName
@ -39,6 +40,7 @@ func (s *Server) SetSessionStorage(storage gsession.Storage) {
return return
} }
s.config.SessionStorage = storage s.config.SessionStorage = storage
s.sessionManager.SetStorage(storage)
} }
// 获取http server参数 - SessionMaxAge // 获取http server参数 - SessionMaxAge

View File

@ -23,7 +23,7 @@ type Manager struct {
func New(ttl time.Duration, storage ...Storage) *Manager { func New(ttl time.Duration, storage ...Storage) *Manager {
m := &Manager{ m := &Manager{
ttl: ttl, ttl: ttl,
storage: NewStorageFile(ttl), storage: NewStorageFile(),
sessions: gcache.New(), sessions: gcache.New(),
} }
if len(storage) > 0 && storage[0] != nil { if len(storage) > 0 && storage[0] != nil {
@ -53,6 +53,11 @@ func (m *Manager) SetStorage(storage Storage) {
m.storage = storage m.storage = storage
} }
// SetTTL the TTL for the session manager.
func (m *Manager) SetTTL(ttl time.Duration) {
m.ttl = ttl
}
// TTL returns the TTL of the session manager. // TTL returns the TTL of the session manager.
func (m *Manager) TTL() time.Duration { func (m *Manager) TTL() time.Duration {
return m.ttl return m.ttl

View File

@ -32,7 +32,7 @@ func (s *Session) init() {
s.dirty = gtype.NewBool(false) s.dirty = gtype.NewBool(false)
} }
if len(s.id) > 0 && s.data == nil { if len(s.id) > 0 && s.data == nil {
if data := s.manager.storage.GetSession(s.id); data != nil { if data := s.manager.storage.GetSession(s.id, s.manager.ttl); data != nil {
if s.data = gmap.NewStrAnyMapFrom(data, true); s.data == nil { if s.data = gmap.NewStrAnyMapFrom(data, true); s.data == nil {
panic("session restoring failed for id:" + s.id) panic("session restoring failed for id:" + s.id)
} }

View File

@ -6,6 +6,8 @@
package gsession package gsession
import "time"
type Storage interface { type Storage interface {
// Get retrieves session value with given key. // Get retrieves session value with given key.
// It returns nil if the key does not exist in the session. // It returns nil if the key does not exist in the session.
@ -26,10 +28,13 @@ type Storage interface {
RemoveAll() error RemoveAll() error
// GetSession returns the session data bytes for given session id. // GetSession returns the session data bytes for given session id.
GetSession(id string) map[string]interface{} // The parameter specifies the TTL for this session.
// It returns nil if the TTL is exceeded.
GetSession(id string, ttl time.Duration) map[string]interface{}
// SetSession updates the content for session id. // SetSession updates the content for session id.
// Note that the parameter <content> is the serialized bytes for session map. // Note that the parameter <content> is the serialized bytes for session map.
SetSession(id string, data map[string]interface{}) error SetSession(id string, data map[string]interface{}) error
// UpdateTTL updates the TTL for specified session id. // UpdateTTL updates the TTL for specified session id.
UpdateTTL(id string) error UpdateTTL(id string) error
} }

View File

@ -28,7 +28,6 @@ import (
// StorageFile implements the Session Storage interface with file system. // StorageFile implements the Session Storage interface with file system.
type StorageFile struct { type StorageFile struct {
ttl time.Duration
path string path string
cryptoKey []byte cryptoKey []byte
cryptoEnabled bool cryptoEnabled bool
@ -51,7 +50,7 @@ func init() {
} }
// NewStorageFile creates and returns a file storage object for session. // NewStorageFile creates and returns a file storage object for session.
func NewStorageFile(ttl time.Duration, path ...string) *StorageFile { func NewStorageFile(path ...string) *StorageFile {
storagePath := DefaultStorageFilePath storagePath := DefaultStorageFilePath
if len(path) > 0 && path[0] != "" { if len(path) > 0 && path[0] != "" {
storagePath, _ = gfile.Search(path[0]) storagePath, _ = gfile.Search(path[0])
@ -68,7 +67,6 @@ func NewStorageFile(ttl time.Duration, path ...string) *StorageFile {
} }
} }
s := &StorageFile{ s := &StorageFile{
ttl: ttl,
path: storagePath, path: storagePath,
cryptoKey: DefaultStorageFileCryptoKey, cryptoKey: DefaultStorageFileCryptoKey,
cryptoEnabled: DefaultStorageFileCryptoEnabled, cryptoEnabled: DefaultStorageFileCryptoEnabled,
@ -135,12 +133,12 @@ func (s *StorageFile) RemoveAll() error {
} }
// GetSession return the session data for given session id. // GetSession return the session data for given session id.
func (s *StorageFile) GetSession(id string) map[string]interface{} { func (s *StorageFile) GetSession(id string, ttl time.Duration) map[string]interface{} {
path := s.sessionFilePath(id) path := s.sessionFilePath(id)
data := gfile.GetBytes(path) data := gfile.GetBytes(path)
if len(data) > 8 { if len(data) > 8 {
timestampMilli := gbinary.DecodeToInt64(data[:8]) timestampMilli := gbinary.DecodeToInt64(data[:8])
if timestampMilli+s.ttl.Nanoseconds()/1e6 < gtime.Millisecond() { if timestampMilli+ttl.Nanoseconds()/1e6 < gtime.Millisecond() {
return nil return nil
} }
var err error var err error

View File

@ -15,7 +15,7 @@ import (
func Test_Manager_Basic(t *testing.T) { func Test_Manager_Basic(t *testing.T) {
ttl := time.Second ttl := time.Second
storage := NewStorageFile(ttl) storage := NewStorageFile()
manager := New(ttl, storage) manager := New(ttl, storage)
sessionId := "" sessionId := ""
gtest.Case(t, func() { gtest.Case(t, func() {