From 56368500a3d05b9deb130ddac912b8334af98012 Mon Sep 17 00:00:00 2001 From: John Date: Sat, 10 Aug 2019 16:39:46 +0800 Subject: [PATCH] add header support for session feature of ghttp.Server --- net/ghttp/ghttp_request.go | 17 +++++++++ net/ghttp/ghttp_server_cookie.go | 19 ---------- net/ghttp/ghttp_server_session.go | 18 +++++----- net/ghttp/ghttp_unit_session_test.go | 54 ++++++++++++++++++++++++++-- 4 files changed, 78 insertions(+), 30 deletions(-) diff --git a/net/ghttp/ghttp_request.go b/net/ghttp/ghttp_request.go index b91f60db7..d673df9af 100644 --- a/net/ghttp/ghttp_request.go +++ b/net/ghttp/ghttp_request.go @@ -227,6 +227,23 @@ func (r *Request) GetUrl() string { return fmt.Sprintf(`%s://%s%s`, scheme, r.Host, r.URL.String()) } +// 从Cookie和Header中查询SESSIONID +func (r *Request) GetSessionId() string { + id := r.Cookie.GetSessionId() + if id == "" { + id = r.Header.Get(r.Server.GetSessionIdName()) + } + return id +} + +// 生成随机的SESSIONID +func (r *Request) MakeSessionId() string { + id := makeSessionId() + r.Cookie.SetSessionId(id) + r.Response.Header().Set(r.Server.GetSessionIdName(), id) + return id +} + // 获得请求来源URL地址 func (r *Request) GetReferer() string { return r.Header.Get("Referer") diff --git a/net/ghttp/ghttp_server_cookie.go b/net/ghttp/ghttp_server_cookie.go index 2ae4f9e32..69c84d701 100644 --- a/net/ghttp/ghttp_server_cookie.go +++ b/net/ghttp/ghttp_server_cookie.go @@ -77,25 +77,6 @@ func (c *Cookie) Map() map[string]string { return m } -// 获取SessionId,不存在时则创建 -func (c *Cookie) SessionId() string { - c.init() - id := c.Get(c.server.GetSessionIdName()) - if id == "" { - id = makeSessionId() - c.SetSessionId(id) - } - return id -} - -// 获取SessionId,不存在时则创建 -func (c *Cookie) MakeSessionId() string { - c.init() - id := makeSessionId() - c.SetSessionId(id) - return id -} - // 判断Cookie中是否存在制定键名(并且没有过期) func (c *Cookie) Contains(key string) bool { c.init() diff --git a/net/ghttp/ghttp_server_session.go b/net/ghttp/ghttp_server_session.go index 34b22d9ea..30686936e 100644 --- a/net/ghttp/ghttp_server_session.go +++ b/net/ghttp/ghttp_server_session.go @@ -55,7 +55,7 @@ func (s *Server) UpdateSession(id string, data map[string]interface{}) { func (s *Session) init() { if len(s.id) == 0 { s.server = s.request.Server - if id := s.request.Cookie.GetSessionId(); id != "" { + if id := s.request.GetSessionId(); id != "" { if v := s.server.sessions.Get(id); v != nil { s.id = id s.data = v.(*gmap.StrAnyMap) @@ -63,7 +63,7 @@ func (s *Session) init() { } } // 否则执行初始化创建 - s.id = s.request.Cookie.MakeSessionId() + s.id = s.request.MakeSessionId() s.data = gmap.NewStrAnyMap(true) s.server.sessions.Set(s.id, s.data, s.server.GetSessionMaxAge()*1000) s.dirty = true @@ -78,7 +78,7 @@ func (s *Session) Id() string { // 获取当前session所有数据,注意是值拷贝 func (s *Session) Map() map[string]interface{} { - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() return s.data.Map() } @@ -87,7 +87,7 @@ func (s *Session) Map() map[string]interface{} { // 获得session map大小 func (s *Session) Size() int { - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() return s.data.Size() } @@ -110,7 +110,7 @@ func (s *Session) Sets(m map[string]interface{}) { // 判断键名是否存在 func (s *Session) Contains(key string) bool { - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() return s.data.Contains(key) } @@ -124,7 +124,7 @@ func (s *Session) IsDirty() bool { // 删除指定session键值对 func (s *Session) Remove(key string) { - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() s.data.Remove(key) s.dirty = true @@ -144,7 +144,7 @@ func (s *Session) Restore(data []byte) (err error) { if len(data) == 0 { return nil } - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() s.data.LockFunc(func(m map[string]interface{}) { err = json.Unmarshal(data, &m) @@ -155,7 +155,7 @@ func (s *Session) Restore(data []byte) (err error) { // 清空session func (s *Session) Clear() { - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() s.data.Clear() s.dirty = true @@ -171,7 +171,7 @@ func (s *Session) UpdateExpire() { // 获取SESSION变量 func (s *Session) Get(key string, def ...interface{}) interface{} { - if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { + if len(s.id) > 0 || s.request.GetSessionId() != "" { s.init() if v := s.data.Get(key); v != nil { return v diff --git a/net/ghttp/ghttp_unit_session_test.go b/net/ghttp/ghttp_unit_session_test.go index 90e63c206..174947b8d 100644 --- a/net/ghttp/ghttp_unit_session_test.go +++ b/net/ghttp/ghttp_unit_session_test.go @@ -4,7 +4,6 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// SESSION测试 package ghttp_test import ( @@ -17,7 +16,7 @@ import ( "github.com/gogf/gf/test/gtest" ) -func Test_Session(t *testing.T) { +func Test_Session_Cookie(t *testing.T) { p := ports.PopRand() s := g.Server(p) s.BindHandler("/set", func(r *ghttp.Request) { @@ -64,3 +63,54 @@ func Test_Session(t *testing.T) { gtest.Assert(client.GetContent("/get?k=key2"), "") }) } + +func Test_Session_Header(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + s.BindHandler("/set", func(r *ghttp.Request) { + r.Session.Set(r.Get("k"), r.Get("v")) + }) + s.BindHandler("/get", func(r *ghttp.Request) { + r.Response.Write(r.Session.Get(r.Get("k"))) + }) + s.BindHandler("/remove", func(r *ghttp.Request) { + r.Session.Remove(r.Get("k")) + }) + s.BindHandler("/clear", func(r *ghttp.Request) { + r.Session.Clear() + }) + s.SetPort(p) + s.SetDumpRouteMap(false) + s.Start() + defer s.Shutdown() + + // 等待启动完成 + time.Sleep(200 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + response, e1 := client.Get("/set?k=key1&v=100") + if response != nil { + defer response.Close() + } + sessionId := response.Header.Get(s.GetSessionIdName()) + gtest.Assert(e1, nil) + gtest.AssertNE(sessionId, nil) + gtest.Assert(response.ReadAllString(), "") + + client.SetHeader(s.GetSessionIdName(), sessionId) + + gtest.Assert(client.GetContent("/set?k=key2&v=200"), "") + + gtest.Assert(client.GetContent("/get?k=key1"), "100") + gtest.Assert(client.GetContent("/get?k=key2"), "200") + gtest.Assert(client.GetContent("/get?k=key3"), "") + gtest.Assert(client.GetContent("/remove?k=key1"), "") + gtest.Assert(client.GetContent("/remove?k=key3"), "") + gtest.Assert(client.GetContent("/remove?k=key4"), "") + gtest.Assert(client.GetContent("/get?k=key1"), "") + gtest.Assert(client.GetContent("/get?k=key2"), "200") + gtest.Assert(client.GetContent("/clear"), "") + gtest.Assert(client.GetContent("/get?k=key2"), "") + }) +}