mirror of
https://gitee.com/johng/gf.git
synced 2024-12-03 04:37:49 +08:00
add header support for session feature of ghttp.Server
This commit is contained in:
parent
846c6a579e
commit
56368500a3
@ -227,6 +227,23 @@ func (r *Request) GetUrl() string {
|
||||
return fmt.Sprintf(`%s://%s%s`, scheme, r.Host, r.URL.String())
|
||||
}
|
||||
|
||||
// 从Cookie和Header中查询SESSIONID
|
||||
func (r *Request) GetSessionId() string {
|
||||
id := r.Cookie.GetSessionId()
|
||||
if id == "" {
|
||||
id = r.Header.Get(r.Server.GetSessionIdName())
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// 生成随机的SESSIONID
|
||||
func (r *Request) MakeSessionId() string {
|
||||
id := makeSessionId()
|
||||
r.Cookie.SetSessionId(id)
|
||||
r.Response.Header().Set(r.Server.GetSessionIdName(), id)
|
||||
return id
|
||||
}
|
||||
|
||||
// 获得请求来源URL地址
|
||||
func (r *Request) GetReferer() string {
|
||||
return r.Header.Get("Referer")
|
||||
|
@ -77,25 +77,6 @@ func (c *Cookie) Map() map[string]string {
|
||||
return m
|
||||
}
|
||||
|
||||
// 获取SessionId,不存在时则创建
|
||||
func (c *Cookie) SessionId() string {
|
||||
c.init()
|
||||
id := c.Get(c.server.GetSessionIdName())
|
||||
if id == "" {
|
||||
id = makeSessionId()
|
||||
c.SetSessionId(id)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// 获取SessionId,不存在时则创建
|
||||
func (c *Cookie) MakeSessionId() string {
|
||||
c.init()
|
||||
id := makeSessionId()
|
||||
c.SetSessionId(id)
|
||||
return id
|
||||
}
|
||||
|
||||
// 判断Cookie中是否存在制定键名(并且没有过期)
|
||||
func (c *Cookie) Contains(key string) bool {
|
||||
c.init()
|
||||
|
@ -55,7 +55,7 @@ func (s *Server) UpdateSession(id string, data map[string]interface{}) {
|
||||
func (s *Session) init() {
|
||||
if len(s.id) == 0 {
|
||||
s.server = s.request.Server
|
||||
if id := s.request.Cookie.GetSessionId(); id != "" {
|
||||
if id := s.request.GetSessionId(); id != "" {
|
||||
if v := s.server.sessions.Get(id); v != nil {
|
||||
s.id = id
|
||||
s.data = v.(*gmap.StrAnyMap)
|
||||
@ -63,7 +63,7 @@ func (s *Session) init() {
|
||||
}
|
||||
}
|
||||
// 否则执行初始化创建
|
||||
s.id = s.request.Cookie.MakeSessionId()
|
||||
s.id = s.request.MakeSessionId()
|
||||
s.data = gmap.NewStrAnyMap(true)
|
||||
s.server.sessions.Set(s.id, s.data, s.server.GetSessionMaxAge()*1000)
|
||||
s.dirty = true
|
||||
@ -78,7 +78,7 @@ func (s *Session) Id() string {
|
||||
|
||||
// 获取当前session所有数据,注意是值拷贝
|
||||
func (s *Session) Map() map[string]interface{} {
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
return s.data.Map()
|
||||
}
|
||||
@ -87,7 +87,7 @@ func (s *Session) Map() map[string]interface{} {
|
||||
|
||||
// 获得session map大小
|
||||
func (s *Session) Size() int {
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
return s.data.Size()
|
||||
}
|
||||
@ -110,7 +110,7 @@ func (s *Session) Sets(m map[string]interface{}) {
|
||||
|
||||
// 判断键名是否存在
|
||||
func (s *Session) Contains(key string) bool {
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
return s.data.Contains(key)
|
||||
}
|
||||
@ -124,7 +124,7 @@ func (s *Session) IsDirty() bool {
|
||||
|
||||
// 删除指定session键值对
|
||||
func (s *Session) Remove(key string) {
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
s.data.Remove(key)
|
||||
s.dirty = true
|
||||
@ -144,7 +144,7 @@ func (s *Session) Restore(data []byte) (err error) {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
s.data.LockFunc(func(m map[string]interface{}) {
|
||||
err = json.Unmarshal(data, &m)
|
||||
@ -155,7 +155,7 @@ func (s *Session) Restore(data []byte) (err error) {
|
||||
|
||||
// 清空session
|
||||
func (s *Session) Clear() {
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
s.data.Clear()
|
||||
s.dirty = true
|
||||
@ -171,7 +171,7 @@ func (s *Session) UpdateExpire() {
|
||||
|
||||
// 获取SESSION变量
|
||||
func (s *Session) Get(key string, def ...interface{}) interface{} {
|
||||
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
|
||||
if len(s.id) > 0 || s.request.GetSessionId() != "" {
|
||||
s.init()
|
||||
if v := s.data.Get(key); v != nil {
|
||||
return v
|
||||
|
@ -4,7 +4,6 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
// SESSION测试
|
||||
package ghttp_test
|
||||
|
||||
import (
|
||||
@ -17,7 +16,7 @@ import (
|
||||
"github.com/gogf/gf/test/gtest"
|
||||
)
|
||||
|
||||
func Test_Session(t *testing.T) {
|
||||
func Test_Session_Cookie(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.BindHandler("/set", func(r *ghttp.Request) {
|
||||
@ -64,3 +63,54 @@ func Test_Session(t *testing.T) {
|
||||
gtest.Assert(client.GetContent("/get?k=key2"), "")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Session_Header(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.BindHandler("/set", func(r *ghttp.Request) {
|
||||
r.Session.Set(r.Get("k"), r.Get("v"))
|
||||
})
|
||||
s.BindHandler("/get", func(r *ghttp.Request) {
|
||||
r.Response.Write(r.Session.Get(r.Get("k")))
|
||||
})
|
||||
s.BindHandler("/remove", func(r *ghttp.Request) {
|
||||
r.Session.Remove(r.Get("k"))
|
||||
})
|
||||
s.BindHandler("/clear", func(r *ghttp.Request) {
|
||||
r.Session.Clear()
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouteMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
|
||||
// 等待启动完成
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
response, e1 := client.Get("/set?k=key1&v=100")
|
||||
if response != nil {
|
||||
defer response.Close()
|
||||
}
|
||||
sessionId := response.Header.Get(s.GetSessionIdName())
|
||||
gtest.Assert(e1, nil)
|
||||
gtest.AssertNE(sessionId, nil)
|
||||
gtest.Assert(response.ReadAllString(), "")
|
||||
|
||||
client.SetHeader(s.GetSessionIdName(), sessionId)
|
||||
|
||||
gtest.Assert(client.GetContent("/set?k=key2&v=200"), "")
|
||||
|
||||
gtest.Assert(client.GetContent("/get?k=key1"), "100")
|
||||
gtest.Assert(client.GetContent("/get?k=key2"), "200")
|
||||
gtest.Assert(client.GetContent("/get?k=key3"), "")
|
||||
gtest.Assert(client.GetContent("/remove?k=key1"), "")
|
||||
gtest.Assert(client.GetContent("/remove?k=key3"), "")
|
||||
gtest.Assert(client.GetContent("/remove?k=key4"), "")
|
||||
gtest.Assert(client.GetContent("/get?k=key1"), "")
|
||||
gtest.Assert(client.GetContent("/get?k=key2"), "200")
|
||||
gtest.Assert(client.GetContent("/clear"), "")
|
||||
gtest.Assert(client.GetContent("/get?k=key2"), "")
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user