add header support for session feature of ghttp.Server

This commit is contained in:
John 2019-08-10 16:39:46 +08:00
parent 846c6a579e
commit 56368500a3
4 changed files with 78 additions and 30 deletions

View File

@ -227,6 +227,23 @@ func (r *Request) GetUrl() string {
return fmt.Sprintf(`%s://%s%s`, scheme, r.Host, r.URL.String())
}
// 从Cookie和Header中查询SESSIONID
func (r *Request) GetSessionId() string {
id := r.Cookie.GetSessionId()
if id == "" {
id = r.Header.Get(r.Server.GetSessionIdName())
}
return id
}
// 生成随机的SESSIONID
func (r *Request) MakeSessionId() string {
id := makeSessionId()
r.Cookie.SetSessionId(id)
r.Response.Header().Set(r.Server.GetSessionIdName(), id)
return id
}
// 获得请求来源URL地址
func (r *Request) GetReferer() string {
return r.Header.Get("Referer")

View File

@ -77,25 +77,6 @@ func (c *Cookie) Map() map[string]string {
return m
}
// 获取SessionId不存在时则创建
func (c *Cookie) SessionId() string {
c.init()
id := c.Get(c.server.GetSessionIdName())
if id == "" {
id = makeSessionId()
c.SetSessionId(id)
}
return id
}
// 获取SessionId不存在时则创建
func (c *Cookie) MakeSessionId() string {
c.init()
id := makeSessionId()
c.SetSessionId(id)
return id
}
// 判断Cookie中是否存在制定键名(并且没有过期)
func (c *Cookie) Contains(key string) bool {
c.init()

View File

@ -55,7 +55,7 @@ func (s *Server) UpdateSession(id string, data map[string]interface{}) {
func (s *Session) init() {
if len(s.id) == 0 {
s.server = s.request.Server
if id := s.request.Cookie.GetSessionId(); id != "" {
if id := s.request.GetSessionId(); id != "" {
if v := s.server.sessions.Get(id); v != nil {
s.id = id
s.data = v.(*gmap.StrAnyMap)
@ -63,7 +63,7 @@ func (s *Session) init() {
}
}
// 否则执行初始化创建
s.id = s.request.Cookie.MakeSessionId()
s.id = s.request.MakeSessionId()
s.data = gmap.NewStrAnyMap(true)
s.server.sessions.Set(s.id, s.data, s.server.GetSessionMaxAge()*1000)
s.dirty = true
@ -78,7 +78,7 @@ func (s *Session) Id() string {
// 获取当前session所有数据注意是值拷贝
func (s *Session) Map() map[string]interface{} {
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
return s.data.Map()
}
@ -87,7 +87,7 @@ func (s *Session) Map() map[string]interface{} {
// 获得session map大小
func (s *Session) Size() int {
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
return s.data.Size()
}
@ -110,7 +110,7 @@ func (s *Session) Sets(m map[string]interface{}) {
// 判断键名是否存在
func (s *Session) Contains(key string) bool {
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
return s.data.Contains(key)
}
@ -124,7 +124,7 @@ func (s *Session) IsDirty() bool {
// 删除指定session键值对
func (s *Session) Remove(key string) {
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
s.data.Remove(key)
s.dirty = true
@ -144,7 +144,7 @@ func (s *Session) Restore(data []byte) (err error) {
if len(data) == 0 {
return nil
}
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
s.data.LockFunc(func(m map[string]interface{}) {
err = json.Unmarshal(data, &m)
@ -155,7 +155,7 @@ func (s *Session) Restore(data []byte) (err error) {
// 清空session
func (s *Session) Clear() {
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
s.data.Clear()
s.dirty = true
@ -171,7 +171,7 @@ func (s *Session) UpdateExpire() {
// 获取SESSION变量
func (s *Session) Get(key string, def ...interface{}) interface{} {
if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" {
if len(s.id) > 0 || s.request.GetSessionId() != "" {
s.init()
if v := s.data.Get(key); v != nil {
return v

View File

@ -4,7 +4,6 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// SESSION测试
package ghttp_test
import (
@ -17,7 +16,7 @@ import (
"github.com/gogf/gf/test/gtest"
)
func Test_Session(t *testing.T) {
func Test_Session_Cookie(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.BindHandler("/set", func(r *ghttp.Request) {
@ -64,3 +63,54 @@ func Test_Session(t *testing.T) {
gtest.Assert(client.GetContent("/get?k=key2"), "")
})
}
func Test_Session_Header(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.BindHandler("/set", func(r *ghttp.Request) {
r.Session.Set(r.Get("k"), r.Get("v"))
})
s.BindHandler("/get", func(r *ghttp.Request) {
r.Response.Write(r.Session.Get(r.Get("k")))
})
s.BindHandler("/remove", func(r *ghttp.Request) {
r.Session.Remove(r.Get("k"))
})
s.BindHandler("/clear", func(r *ghttp.Request) {
r.Session.Clear()
})
s.SetPort(p)
s.SetDumpRouteMap(false)
s.Start()
defer s.Shutdown()
// 等待启动完成
time.Sleep(200 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
response, e1 := client.Get("/set?k=key1&v=100")
if response != nil {
defer response.Close()
}
sessionId := response.Header.Get(s.GetSessionIdName())
gtest.Assert(e1, nil)
gtest.AssertNE(sessionId, nil)
gtest.Assert(response.ReadAllString(), "")
client.SetHeader(s.GetSessionIdName(), sessionId)
gtest.Assert(client.GetContent("/set?k=key2&v=200"), "")
gtest.Assert(client.GetContent("/get?k=key1"), "100")
gtest.Assert(client.GetContent("/get?k=key2"), "200")
gtest.Assert(client.GetContent("/get?k=key3"), "")
gtest.Assert(client.GetContent("/remove?k=key1"), "")
gtest.Assert(client.GetContent("/remove?k=key3"), "")
gtest.Assert(client.GetContent("/remove?k=key4"), "")
gtest.Assert(client.GetContent("/get?k=key1"), "")
gtest.Assert(client.GetContent("/get?k=key2"), "200")
gtest.Assert(client.GetContent("/clear"), "")
gtest.Assert(client.GetContent("/get?k=key2"), "")
})
}