improve session feature by allowing custom session id creating function for ghttp.Server

This commit is contained in:
john 2020-05-17 14:33:21 +08:00
parent 351de5ee6a
commit 45a94d23d5
6 changed files with 70 additions and 18 deletions

1
go.mod
View File

@ -11,6 +11,7 @@ require (
github.com/gorilla/websocket v1.4.1
github.com/gqcn/structs v1.1.1
github.com/grokify/html-strip-tags-go v0.0.0-20190921062105-daaa06bf1aaf
github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/olekukonko/tablewriter v0.0.1
golang.org/x/text v0.3.2
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c

View File

@ -9,12 +9,14 @@ package ghttp
import (
"context"
"fmt"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/os/gres"
"github.com/gogf/gf/os/gsession"
"github.com/gogf/gf/os/gview"
"github.com/gogf/gf/util/guid"
"net/http"
"strings"
"github.com/gogf/gf/os/gsession"
"time"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/text/gregex"
@ -75,6 +77,19 @@ func newRequest(s *Server, r *http.Request, w http.ResponseWriter) *Request {
request.Middleware = &Middleware{
request: request,
}
// Custom session id creating function.
err := request.Session.SetIdFunc(func(ttl time.Duration) string {
var (
agent = request.UserAgent()
address = request.RemoteAddr
cookie = request.Header.Get("Cookie")
)
intlog.Print(agent, address, cookie)
return guid.S([]byte(agent), []byte(address), []byte(cookie))
})
if err != nil {
panic(err)
}
return request
}

View File

@ -8,5 +8,6 @@ package ghttp
import "github.com/gogf/gf/os/gsession"
// Session is actually a alias of gsession.Session.
// Session is actually a alias of gsession.Session,
// which is bound to a single request.
type Session = gsession.Session

View File

@ -16,7 +16,7 @@ import (
// Manager for sessions.
type Manager struct {
ttl time.Duration // TTL for sessions.
storage Storage // Storage interface for session storage Set/Get.
storage Storage // Storage interface for session storage.
sessionData *gcache.Cache // Session data cache for session TTL.
}

View File

@ -7,6 +7,7 @@
package gsession
import (
"errors"
"github.com/gogf/gf/internal/intlog"
"time"
@ -16,22 +17,27 @@ import (
"github.com/gogf/gf/util/gconv"
)
// Session struct for storing single session data.
// Session struct for storing single session data,
// which is bound to a single request.
type Session struct {
id string // Session id.
data *gmap.StrAnyMap // Session data.
dirty bool // Used to mark session is modified.
start bool // Used to mark session is started.
manager *Manager // Parent manager.
// idFunc is a callback function used for creating custom session id.
// This is called if session id is empty ever when session starts.
idFunc func(ttl time.Duration) (id string)
}
// init does the delay initialization for session.
// init does the lazy initialization for session.
// It here initializes real session if necessary.
func (s *Session) init() {
if s.start {
return
}
if len(s.id) > 0 {
if s.id != "" {
var err error
// Retrieve memory session data from manager.
if r := s.manager.sessionData.Get(s.id); r != nil {
@ -50,10 +56,16 @@ func (s *Session) init() {
s.id = ""
}
}
if len(s.id) == 0 {
// Use custom session id creating function.
if s.id == "" && s.idFunc != nil {
s.id = s.idFunc(s.manager.ttl)
}
// Use default session id creating function of storage.
if s.id == "" {
s.id = s.manager.storage.New(s.manager.ttl)
}
if len(s.id) == 0 {
// Use default session id creating function.
if s.id == "" {
s.id = NewSessionId()
}
if s.data == nil {
@ -67,7 +79,7 @@ func (s *Session) init() {
//
// NOTE that this function must be called ever after a session request done.
func (s *Session) Close() {
if s.start && len(s.id) > 0 {
if s.start && s.id != "" {
size := s.data.Size()
if s.manager.storage != nil {
if s.dirty {
@ -116,7 +128,7 @@ func (s *Session) Sets(data map[string]interface{}) error {
// Remove removes key along with its value from this session.
func (s *Session) Remove(key string) error {
if len(s.id) == 0 {
if s.id == "" {
return nil
}
s.init()
@ -138,7 +150,7 @@ func (s *Session) Clear() error {
// RemoveAll deletes all key-value pairs from this session.
func (s *Session) RemoveAll() error {
if len(s.id) == 0 {
if s.id == "" {
return nil
}
s.init()
@ -160,10 +172,30 @@ func (s *Session) Id() string {
return s.id
}
// SetId sets custom session before session starts.
// It returns error if it is called after session starts.
func (s *Session) SetId(id string) error {
if s.start {
return errors.New("session already started")
}
s.id = id
return nil
}
// SetIdFunc sets custom session id creating function before session starts.
// It returns error if it is called after session starts.
func (s *Session) SetIdFunc(f func(ttl time.Duration) string) error {
if s.start {
return errors.New("session already started")
}
s.idFunc = f
return nil
}
// Map returns all data as map.
// Note that it's using value copy internally for concurrent-safe purpose.
func (s *Session) Map() map[string]interface{} {
if len(s.id) > 0 {
if s.id != "" {
s.init()
if data := s.manager.storage.GetMap(s.id); data != nil {
return data
@ -175,7 +207,7 @@ func (s *Session) Map() map[string]interface{} {
// Size returns the size of the session.
func (s *Session) Size() int {
if len(s.id) > 0 {
if s.id != "" {
s.init()
if size := s.manager.storage.GetSize(s.id); size >= 0 {
return size
@ -200,7 +232,7 @@ func (s *Session) IsDirty() bool {
// It returns <def> if the key does not exist in the session if <def> is given,
// or else it return nil.
func (s *Session) Get(key string, def ...interface{}) interface{} {
if len(s.id) == 0 {
if s.id == "" {
return nil
}
s.init()

View File

@ -26,7 +26,7 @@ import (
var (
sequence gtype.Uint32 // Sequence for unique purpose of current process.
sequenceMax = uint32(1000000) // Sequence max.
randomStrBase = "0123456789abcdefghijklmnopqrstuvwxyz" // 36
randomStrBase = "0123456789abcdefghijklmnopqrstuvwxyz" // Random chars string(36 bytes).
macAddrStr = "0000000" // MAC addresses hash result in 7 bytes.
processIdStr = "0000" // Process id in 4 bytes.
)
@ -82,8 +82,11 @@ func S(data ...[]byte) string {
} else if len(data) <= 3 {
n := 0
for i, v := range data {
copy(b[i*7:], getDataHashStr(v))
n += 7
// Ignore empty data item bytes.
if len(v) > 0 {
copy(b[i*7:], getDataHashStr(v))
n += 7
}
}
copy(b[n:], nanoStr)
copy(b[n+12:], getRandomStr(36-n-12))