From 4e7c6c1fb4fffe97fe158904720e589429334668 Mon Sep 17 00:00:00 2001 From: John Date: Wed, 4 Mar 2020 22:52:56 +0800 Subject: [PATCH] improve CORS feature for ghttp.Server --- net/ghttp/ghttp_request_middleware.go | 1 + net/ghttp/ghttp_response_cors.go | 53 ++++++++----- net/ghttp/ghttp_server.go | 1 + net/ghttp/ghttp_server_router.go | 61 +++++---------- net/ghttp/ghttp_server_router_serve.go | 35 ++++++--- ...go => ghttp_unit_middleware_basic_test.go} | 35 ++++++++- net/ghttp/ghttp_unit_middleware_cors_test.go | 76 +++++++++++++++++++ net/ghttp/ghttp_unit_router_hook_test.go | 23 +++--- 8 files changed, 196 insertions(+), 89 deletions(-) rename net/ghttp/{ghttp_unit_middleware_test.go => ghttp_unit_middleware_basic_test.go} (95%) create mode 100644 net/ghttp/ghttp_unit_middleware_cors_test.go diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index f4e4907e9..050c16f0e 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -123,6 +123,7 @@ func (m *Middleware) Next() { }, func(exception interface{}) { m.request.error = gerror.Newf("%v", exception) m.request.Response.WriteStatus(http.StatusInternalServerError, exception) + loop = false }) } // Check the http status code after all handler and middleware done. diff --git a/net/ghttp/ghttp_response_cors.go b/net/ghttp/ghttp_response_cors.go index 835a1c8e2..59fc35d5d 100644 --- a/net/ghttp/ghttp_response_cors.go +++ b/net/ghttp/ghttp_response_cors.go @@ -8,11 +8,10 @@ package ghttp import ( - "net/http" - "net/url" - "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" + "net/http" + "net/url" ) // CORSOptions is the options for CORS feature. @@ -29,18 +28,18 @@ type CORSOptions struct { var ( // defaultAllowHeaders is the default allowed headers for CORS. - // It's defined as map for better header key searching performance. - defaultAllowHeaders = map[string]struct{}{ - "Origin": {}, - "Accept": {}, - "Cookie": {}, - "Authorization": {}, - "X-Auth-Token": {}, - "X-Requested-With": {}, - "Content-Type": {}, - } + // It's defined another map for better header key searching performance. + defaultAllowHeaders = "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With" + defaultAllowHeadersMap = make(map[string]struct{}) ) +func init() { + array := gstr.SplitAndTrim(defaultAllowHeaders, ",") + for _, header := range array { + defaultAllowHeadersMap[header] = struct{}{} + } +} + // DefaultCORSOptions returns the default CORS options, // which allows any cross-domain request. func (r *Response) DefaultCORSOptions() CORSOptions { @@ -48,22 +47,17 @@ func (r *Response) DefaultCORSOptions() CORSOptions { AllowOrigin: "*", AllowMethods: HTTP_METHODS, AllowCredentials: "true", + AllowHeaders: defaultAllowHeaders, MaxAge: 3628800, } // Allow all client's custom headers in default. if headers := r.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { array := gstr.SplitAndTrim(headers, ",") for _, header := range array { - if _, ok := defaultAllowHeaders[header]; !ok { + if _, ok := defaultAllowHeadersMap[header]; !ok { options.AllowHeaders += header + "," } } - for header, _ := range defaultAllowHeaders { - if len(options.AllowHeaders) > 0 { - options.AllowHeaders += "," - } - options.AllowHeaders += header - } } // Allow all anywhere origin in default. if origin := r.Request.Header.Get("Origin"); origin != "" { @@ -101,8 +95,25 @@ func (r *Response) CORS(options CORSOptions) { } // No continue service handling if it's OPTIONS request. if gstr.Equal(r.Request.Method, "OPTIONS") { + // Request method's handler searching. + // It here uses Server.routesMap attribute enhancing the searching performance. + if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" { + routerKey := "" + for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} { + for _, v := range []string{gDEFAULT_METHOD, method} { + routerKey = r.Server.handlerKey("", v, r.Request.URL.Path, domain) + if r.Server.routesMap[routerKey] != nil { + if r.Status == 0 { + r.Status = http.StatusOK + } + r.Request.ExitAll() + } + } + } + } + // Cannot find the request handler. if r.Status == 0 { - r.Status = http.StatusOK + r.Status = http.StatusNotFound } r.Request.ExitAll() } diff --git a/net/ghttp/ghttp_server.go b/net/ghttp/ghttp_server.go index d79aa80c7..b446a17e8 100644 --- a/net/ghttp/ghttp_server.go +++ b/net/ghttp/ghttp_server.go @@ -79,6 +79,7 @@ type ( // handlerItem is the registered handler for route handling, // including middleware and hook functions. handlerItem struct { + itemId int // Unique handler item id mark. itemName string // Handler name, which is automatically retrieved from runtime stack when registered. itemType int // Handler type: object/handler/controller/middleware/hook. itemFunc HandlerFunc // Handler address. diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 7e7c2f353..29426863c 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -9,7 +9,7 @@ package ghttp import ( "errors" "fmt" - "github.com/gogf/gf/util/gutil" + "github.com/gogf/gf/container/gtype" "strings" "github.com/gogf/gf/debug/gdebug" @@ -23,6 +23,11 @@ const ( gFILTER_KEY = "/net/ghttp/ghttp" ) +var ( + // handlerIdGenerator is handler item id generator. + handlerIdGenerator = gtype.NewInt() +) + // handlerKey creates and returns an unique router key for given parameters. func (s *Server) handlerKey(hook, method, path, domain string) string { return hook + "%" + s.serveHandlerKey(method, path, domain) @@ -59,6 +64,7 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err // This function is called during server starts up, which cares little about the performance. What really cares // is the well designed router storage structure for router searching when the request is under serving. func (s *Server) setHandler(pattern string, handler *handlerItem) { + handler.itemId = handlerIdGenerator.Add(1) domain, method, uri, err := s.parsePattern(pattern) if err != nil { s.Logger().Fatal("invalid pattern:", pattern, err) @@ -70,11 +76,11 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } // Repeated router checks, this feature can be disabled by server configuration. - regKey := s.handlerKey(handler.hookName, method, uri, domain) + routerKey := s.handlerKey(handler.hookName, method, uri, domain) if !s.config.RouteOverWrite { switch handler.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: - if item, ok := s.routesMap[regKey]; ok { + if item, ok := s.routesMap[routerKey]; ok { s.Logger().Fatalf(`duplicated route registry "%s", already registered at %s`, pattern, item[0].file) return } @@ -143,47 +149,14 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { // fuzzy checks. if i == len(array)-1 && part != "*fuzz" { if v, ok := p.(map[string]interface{})["*list"]; !ok { - list := glist.New() - p.(map[string]interface{})["*list"] = list - lists = append(lists, list) + leafList := glist.New() + p.(map[string]interface{})["*list"] = leafList + lists = append(lists, leafList) } else { lists = append(lists, v.(*glist.List)) } } } - - for k, v := range array { - if len(v) == 0 { - continue - } - // 判断是否模糊匹配规则 - if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, v) { - v = "*fuzz" - // 由于是模糊规则,因此这里会有一个*list,用以将后续的路由规则加进来, - // 检索会从叶子节点的链表往根节点按照优先级进行检索 - if v, ok := p.(map[string]interface{})["*list"]; !ok { - p.(map[string]interface{})["*list"] = glist.New() - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) - } else { - lists = append(lists, v.(*glist.List)) - } - } - // 属性层级数据写入 - if _, ok := p.(map[string]interface{})[v]; !ok { - p.(map[string]interface{})[v] = make(map[string]interface{}) - } - p = p.(map[string]interface{})[v] - // 到达叶子节点,往list中增加匹配规则(条件 v != "*fuzz" 是因为模糊节点的话在前面已经添加了*list链表) - if k == len(array)-1 && v != "*fuzz" { - if v, ok := p.(map[string]interface{})["*list"]; !ok { - p.(map[string]interface{})["*list"] = glist.New() - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) - } else { - lists = append(lists, v.(*glist.List)) - } - } - } - // It iterates the list array of , compares priorities and inserts the new router item in // the proper position of each list. The priority of the list is ordered from high to low. item := (*handlerItem)(nil) @@ -206,8 +179,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } } // Initialize the route map item. - if _, ok := s.routesMap[regKey]; !ok { - s.routesMap[regKey] = make([]registeredRouteItem, 0) + if _, ok := s.routesMap[routerKey]; !ok { + s.routesMap[routerKey] = make([]registeredRouteItem, 0) } _, file, line := gdebug.CallerWithFilter(gFILTER_KEY) routeItem := registeredRouteItem{ @@ -217,12 +190,12 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { switch handler.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: // Overwrite the route. - s.routesMap[regKey] = []registeredRouteItem{routeItem} + s.routesMap[routerKey] = []registeredRouteItem{routeItem} default: // Append the route. - s.routesMap[regKey] = append(s.routesMap[regKey], routeItem) + s.routesMap[routerKey] = append(s.routesMap[routerKey], routeItem) } - gutil.Dump(s.serveTree) + //gutil.Dump(s.serveTree) } // 对比两个handlerItem的优先级,需要非常注意的是,注意新老对比项的参数先后顺序。 diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index 7306a5614..0fb563b8a 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -15,7 +15,7 @@ import ( "github.com/gogf/gf/text/gregex" ) -// handlerCacheItem is a item for router cache. +// handlerCacheItem is an item for router searching cache. type handlerCacheItem struct { parsedItems []*handlerParsedItem hasHook bool @@ -24,8 +24,17 @@ type handlerCacheItem struct { // getHandlersWithCache searches the router item with cache feature for given request. func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { - value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} { - parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost()) + method := r.Method + // Special http method OPTIONS handling. + // It searches the handler with the request method instead of OPTIONS method. + if method == "OPTIONS" { + if v := r.Request.Header.Get("Access-Control-Request-Method"); v != "" { + method = v + } + } + // Search and cache the router handlers. + value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(method, r.URL.Path, r.GetHost()), func() interface{} { + parsedItems, hasHook, hasServe = s.searchHandlers(method, r.URL.Path, r.GetHost()) if parsedItems != nil { return &handlerCacheItem{parsedItems, hasHook, hasServe} } @@ -44,11 +53,6 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han if len(path) == 0 { return nil, false, false } - // Default domain has the most priority when iteration. - domains := []string{gDEFAULT_DOMAIN} - if !strings.EqualFold(gDEFAULT_DOMAIN, domain) { - domains = append(domains, domain) - } // Split the URL.path to separate parts. var array []string if strings.EqualFold("/", path) { @@ -58,11 +62,14 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han } parsedItemList := glist.New() lastMiddlewareElem := (*glist.Element)(nil) - for _, domain := range domains { + repeatHandlerCheckMap := make(map[int]struct{}, 16) + // Default domain has the most priority when iteration. + for _, domain := range []string{gDEFAULT_DOMAIN, domain} { p, ok := s.serveTree[domain] if !ok { continue } + // Make a list array with capacity of 16. lists := make([]*glist.List, 0, 16) for i, part := range array { // In case of double '/' URI, eg: /user//index @@ -72,8 +79,8 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han if v, ok := p.(map[string]interface{})["*list"]; ok { lists = append(lists, v.(*glist.List)) } - if _, ok := p.(map[string]interface{})[part]; ok { - p = p.(map[string]interface{})[part] + if v, ok := p.(map[string]interface{})[part]; ok { + p = v if i == len(array)-1 { if v, ok := p.(map[string]interface{})["*list"]; ok { lists = append(lists, v.(*glist.List)) @@ -100,6 +107,12 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han for i := len(lists) - 1; i >= 0; i-- { for e := lists[i].Front(); e != nil; e = e.Next() { item := e.Value.(*handlerItem) + // 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数) + if _, ok := repeatHandlerCheckMap[item.itemId]; ok { + continue + } else { + repeatHandlerCheckMap[item.itemId] = struct{}{} + } // 服务路由函数只能添加一次,将重复判断放在这里提高检索效率 if hasServe { switch item.itemType { diff --git a/net/ghttp/ghttp_unit_middleware_test.go b/net/ghttp/ghttp_unit_middleware_basic_test.go similarity index 95% rename from net/ghttp/ghttp_unit_middleware_test.go rename to net/ghttp/ghttp_unit_middleware_basic_test.go index cfb56b4c1..0ff2c4beb 100644 --- a/net/ghttp/ghttp_unit_middleware_test.go +++ b/net/ghttp/ghttp_unit_middleware_basic_test.go @@ -594,8 +594,9 @@ func MiddlewareCORS(r *ghttp.Request) { func Test_Middleware_CORSAndAuth(t *testing.T) { p := ports.PopRand() s := g.Server(p) + s.Use(MiddlewareCORS) s.Group("/api.v2", func(group *ghttp.RouterGroup) { - group.Middleware(MiddlewareAuth, MiddlewareCORS) + group.Middleware(MiddlewareAuth) group.POST("/user/list", func(r *ghttp.Request) { r.Response.Write("list") }) @@ -680,3 +681,35 @@ func Test_Middleware_Scope(t *testing.T) { gtest.Assert(client.GetContent("/scope3"), "ae3fb") }) } + +func Test_Middleware_Panic(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + i := 0 + s.Group("/", func(group *ghttp.RouterGroup) { + group.Group("/", func(group *ghttp.RouterGroup) { + group.Middleware(func(r *ghttp.Request) { + i++ + panic("error") + r.Middleware.Next() + }, func(r *ghttp.Request) { + i++ + r.Middleware.Next() + }) + group.ALL("/", func(r *ghttp.Request) { + r.Response.Write(i) + }) + }) + }) + s.SetPort(p) + //s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + gtest.Assert(client.GetContent("/"), "error") + }) +} diff --git a/net/ghttp/ghttp_unit_middleware_cors_test.go b/net/ghttp/ghttp_unit_middleware_cors_test.go new file mode 100644 index 000000000..c99e16c97 --- /dev/null +++ b/net/ghttp/ghttp_unit_middleware_cors_test.go @@ -0,0 +1,76 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp_test + +import ( + "fmt" + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/test/gtest" + "testing" + "time" +) + +func Test_Middleware_CORS(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + s.Group("/api.v2", func(group *ghttp.RouterGroup) { + group.Middleware(MiddlewareCORS) + group.POST("/user/list", func(r *ghttp.Request) { + r.Response.Write("list") + }) + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + // Common Checks. + gtest.Assert(client.GetContent("/"), "Not Found") + gtest.Assert(client.GetContent("/api.v2"), "Not Found") + + // GET request does not any route. + resp, err := client.Get("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0) + resp.Close() + + // POST request matches the route and CORS middleware. + resp, err = client.Post("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1) + gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With") + gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE") + gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*") + gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800") + resp.Close() + }) + // OPTIONS GET + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "GET") + resp, err := client.Options("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0) + gtest.Assert(resp.ReadAllString(), "Not Found") + resp.Close() + }) + // OPTIONS POST + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "POST") + resp, err := client.Options("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1) + resp.Close() + }) +} diff --git a/net/ghttp/ghttp_unit_router_hook_test.go b/net/ghttp/ghttp_unit_router_hook_test.go index 1edc604fe..c3b592b79 100644 --- a/net/ghttp/ghttp_unit_router_hook_test.go +++ b/net/ghttp/ghttp_unit_router_hook_test.go @@ -50,7 +50,6 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) { pattern1 := "/:name/info" s.BindHookHandlerByMap(pattern1, map[string]ghttp.HandlerFunc{ ghttp.HOOK_BEFORE_SERVE: func(r *ghttp.Request) { - fmt.Println("called") r.SetParam("uid", i) i++ }, @@ -59,17 +58,17 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) { r.Response.Write(r.Get("uid")) }) - //pattern2 := "/{object}/list/{page}.java" - //s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{ - // ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) { - // r.Response.SetBuffer([]byte( - // fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i), - // )) - // }, - //}) - //s.BindHandler(pattern2, func(r *ghttp.Request) { - // r.Response.Write(r.Router.Uri) - //}) + pattern2 := "/{object}/list/{page}.java" + s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{ + ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) { + r.Response.SetBuffer([]byte( + fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i), + )) + }, + }) + s.BindHandler(pattern2, func(r *ghttp.Request) { + r.Response.Write(r.Router.Uri) + }) s.SetPort(p) //s.SetDumpRouterMap(false) s.Start()