improve CORS feature for ghttp.Server

This commit is contained in:
John 2020-03-04 22:52:56 +08:00
parent d8a7e36478
commit 4e7c6c1fb4
8 changed files with 196 additions and 89 deletions

View File

@ -123,6 +123,7 @@ func (m *Middleware) Next() {
}, func(exception interface{}) {
m.request.error = gerror.Newf("%v", exception)
m.request.Response.WriteStatus(http.StatusInternalServerError, exception)
loop = false
})
}
// Check the http status code after all handler and middleware done.

View File

@ -8,11 +8,10 @@
package ghttp
import (
"net/http"
"net/url"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"net/http"
"net/url"
)
// CORSOptions is the options for CORS feature.
@ -29,18 +28,18 @@ type CORSOptions struct {
var (
// defaultAllowHeaders is the default allowed headers for CORS.
// It's defined as map for better header key searching performance.
defaultAllowHeaders = map[string]struct{}{
"Origin": {},
"Accept": {},
"Cookie": {},
"Authorization": {},
"X-Auth-Token": {},
"X-Requested-With": {},
"Content-Type": {},
}
// It's defined another map for better header key searching performance.
defaultAllowHeaders = "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With"
defaultAllowHeadersMap = make(map[string]struct{})
)
func init() {
array := gstr.SplitAndTrim(defaultAllowHeaders, ",")
for _, header := range array {
defaultAllowHeadersMap[header] = struct{}{}
}
}
// DefaultCORSOptions returns the default CORS options,
// which allows any cross-domain request.
func (r *Response) DefaultCORSOptions() CORSOptions {
@ -48,22 +47,17 @@ func (r *Response) DefaultCORSOptions() CORSOptions {
AllowOrigin: "*",
AllowMethods: HTTP_METHODS,
AllowCredentials: "true",
AllowHeaders: defaultAllowHeaders,
MaxAge: 3628800,
}
// Allow all client's custom headers in default.
if headers := r.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
array := gstr.SplitAndTrim(headers, ",")
for _, header := range array {
if _, ok := defaultAllowHeaders[header]; !ok {
if _, ok := defaultAllowHeadersMap[header]; !ok {
options.AllowHeaders += header + ","
}
}
for header, _ := range defaultAllowHeaders {
if len(options.AllowHeaders) > 0 {
options.AllowHeaders += ","
}
options.AllowHeaders += header
}
}
// Allow all anywhere origin in default.
if origin := r.Request.Header.Get("Origin"); origin != "" {
@ -101,8 +95,25 @@ func (r *Response) CORS(options CORSOptions) {
}
// No continue service handling if it's OPTIONS request.
if gstr.Equal(r.Request.Method, "OPTIONS") {
// Request method's handler searching.
// It here uses Server.routesMap attribute enhancing the searching performance.
if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" {
routerKey := ""
for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} {
for _, v := range []string{gDEFAULT_METHOD, method} {
routerKey = r.Server.handlerKey("", v, r.Request.URL.Path, domain)
if r.Server.routesMap[routerKey] != nil {
if r.Status == 0 {
r.Status = http.StatusOK
}
r.Request.ExitAll()
}
}
}
}
// Cannot find the request handler.
if r.Status == 0 {
r.Status = http.StatusOK
r.Status = http.StatusNotFound
}
r.Request.ExitAll()
}

View File

@ -79,6 +79,7 @@ type (
// handlerItem is the registered handler for route handling,
// including middleware and hook functions.
handlerItem struct {
itemId int // Unique handler item id mark.
itemName string // Handler name, which is automatically retrieved from runtime stack when registered.
itemType int // Handler type: object/handler/controller/middleware/hook.
itemFunc HandlerFunc // Handler address.

View File

@ -9,7 +9,7 @@ package ghttp
import (
"errors"
"fmt"
"github.com/gogf/gf/util/gutil"
"github.com/gogf/gf/container/gtype"
"strings"
"github.com/gogf/gf/debug/gdebug"
@ -23,6 +23,11 @@ const (
gFILTER_KEY = "/net/ghttp/ghttp"
)
var (
// handlerIdGenerator is handler item id generator.
handlerIdGenerator = gtype.NewInt()
)
// handlerKey creates and returns an unique router key for given parameters.
func (s *Server) handlerKey(hook, method, path, domain string) string {
return hook + "%" + s.serveHandlerKey(method, path, domain)
@ -59,6 +64,7 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err
// This function is called during server starts up, which cares little about the performance. What really cares
// is the well designed router storage structure for router searching when the request is under serving.
func (s *Server) setHandler(pattern string, handler *handlerItem) {
handler.itemId = handlerIdGenerator.Add(1)
domain, method, uri, err := s.parsePattern(pattern)
if err != nil {
s.Logger().Fatal("invalid pattern:", pattern, err)
@ -70,11 +76,11 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
}
// Repeated router checks, this feature can be disabled by server configuration.
regKey := s.handlerKey(handler.hookName, method, uri, domain)
routerKey := s.handlerKey(handler.hookName, method, uri, domain)
if !s.config.RouteOverWrite {
switch handler.itemType {
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
if item, ok := s.routesMap[regKey]; ok {
if item, ok := s.routesMap[routerKey]; ok {
s.Logger().Fatalf(`duplicated route registry "%s", already registered at %s`, pattern, item[0].file)
return
}
@ -143,47 +149,14 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
// fuzzy checks.
if i == len(array)-1 && part != "*fuzz" {
if v, ok := p.(map[string]interface{})["*list"]; !ok {
list := glist.New()
p.(map[string]interface{})["*list"] = list
lists = append(lists, list)
leafList := glist.New()
p.(map[string]interface{})["*list"] = leafList
lists = append(lists, leafList)
} else {
lists = append(lists, v.(*glist.List))
}
}
}
for k, v := range array {
if len(v) == 0 {
continue
}
// 判断是否模糊匹配规则
if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, v) {
v = "*fuzz"
// 由于是模糊规则,因此这里会有一个*list用以将后续的路由规则加进来
// 检索会从叶子节点的链表往根节点按照优先级进行检索
if v, ok := p.(map[string]interface{})["*list"]; !ok {
p.(map[string]interface{})["*list"] = glist.New()
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
} else {
lists = append(lists, v.(*glist.List))
}
}
// 属性层级数据写入
if _, ok := p.(map[string]interface{})[v]; !ok {
p.(map[string]interface{})[v] = make(map[string]interface{})
}
p = p.(map[string]interface{})[v]
// 到达叶子节点往list中增加匹配规则(条件 v != "*fuzz" 是因为模糊节点的话在前面已经添加了*list链表)
if k == len(array)-1 && v != "*fuzz" {
if v, ok := p.(map[string]interface{})["*list"]; !ok {
p.(map[string]interface{})["*list"] = glist.New()
lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List))
} else {
lists = append(lists, v.(*glist.List))
}
}
}
// It iterates the list array of <lists>, compares priorities and inserts the new router item in
// the proper position of each list. The priority of the list is ordered from high to low.
item := (*handlerItem)(nil)
@ -206,8 +179,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
}
}
// Initialize the route map item.
if _, ok := s.routesMap[regKey]; !ok {
s.routesMap[regKey] = make([]registeredRouteItem, 0)
if _, ok := s.routesMap[routerKey]; !ok {
s.routesMap[routerKey] = make([]registeredRouteItem, 0)
}
_, file, line := gdebug.CallerWithFilter(gFILTER_KEY)
routeItem := registeredRouteItem{
@ -217,12 +190,12 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) {
switch handler.itemType {
case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER:
// Overwrite the route.
s.routesMap[regKey] = []registeredRouteItem{routeItem}
s.routesMap[routerKey] = []registeredRouteItem{routeItem}
default:
// Append the route.
s.routesMap[regKey] = append(s.routesMap[regKey], routeItem)
s.routesMap[routerKey] = append(s.routesMap[routerKey], routeItem)
}
gutil.Dump(s.serveTree)
//gutil.Dump(s.serveTree)
}
// 对比两个handlerItem的优先级需要非常注意的是注意新老对比项的参数先后顺序。

View File

@ -15,7 +15,7 @@ import (
"github.com/gogf/gf/text/gregex"
)
// handlerCacheItem is a item for router cache.
// handlerCacheItem is an item for router searching cache.
type handlerCacheItem struct {
parsedItems []*handlerParsedItem
hasHook bool
@ -24,8 +24,17 @@ type handlerCacheItem struct {
// getHandlersWithCache searches the router item with cache feature for given request.
func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) {
value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} {
parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost())
method := r.Method
// Special http method OPTIONS handling.
// It searches the handler with the request method instead of OPTIONS method.
if method == "OPTIONS" {
if v := r.Request.Header.Get("Access-Control-Request-Method"); v != "" {
method = v
}
}
// Search and cache the router handlers.
value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(method, r.URL.Path, r.GetHost()), func() interface{} {
parsedItems, hasHook, hasServe = s.searchHandlers(method, r.URL.Path, r.GetHost())
if parsedItems != nil {
return &handlerCacheItem{parsedItems, hasHook, hasServe}
}
@ -44,11 +53,6 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
if len(path) == 0 {
return nil, false, false
}
// Default domain has the most priority when iteration.
domains := []string{gDEFAULT_DOMAIN}
if !strings.EqualFold(gDEFAULT_DOMAIN, domain) {
domains = append(domains, domain)
}
// Split the URL.path to separate parts.
var array []string
if strings.EqualFold("/", path) {
@ -58,11 +62,14 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
}
parsedItemList := glist.New()
lastMiddlewareElem := (*glist.Element)(nil)
for _, domain := range domains {
repeatHandlerCheckMap := make(map[int]struct{}, 16)
// Default domain has the most priority when iteration.
for _, domain := range []string{gDEFAULT_DOMAIN, domain} {
p, ok := s.serveTree[domain]
if !ok {
continue
}
// Make a list array with capacity of 16.
lists := make([]*glist.List, 0, 16)
for i, part := range array {
// In case of double '/' URI, eg: /user//index
@ -72,8 +79,8 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
if v, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, v.(*glist.List))
}
if _, ok := p.(map[string]interface{})[part]; ok {
p = p.(map[string]interface{})[part]
if v, ok := p.(map[string]interface{})[part]; ok {
p = v
if i == len(array)-1 {
if v, ok := p.(map[string]interface{})["*list"]; ok {
lists = append(lists, v.(*glist.List))
@ -100,6 +107,12 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
for i := len(lists) - 1; i >= 0; i-- {
for e := lists[i].Front(); e != nil; e = e.Next() {
item := e.Value.(*handlerItem)
// 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数)
if _, ok := repeatHandlerCheckMap[item.itemId]; ok {
continue
} else {
repeatHandlerCheckMap[item.itemId] = struct{}{}
}
// 服务路由函数只能添加一次,将重复判断放在这里提高检索效率
if hasServe {
switch item.itemType {

View File

@ -594,8 +594,9 @@ func MiddlewareCORS(r *ghttp.Request) {
func Test_Middleware_CORSAndAuth(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Use(MiddlewareCORS)
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
group.Middleware(MiddlewareAuth, MiddlewareCORS)
group.Middleware(MiddlewareAuth)
group.POST("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
@ -680,3 +681,35 @@ func Test_Middleware_Scope(t *testing.T) {
gtest.Assert(client.GetContent("/scope3"), "ae3fb")
})
}
func Test_Middleware_Panic(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
i := 0
s.Group("/", func(group *ghttp.RouterGroup) {
group.Group("/", func(group *ghttp.RouterGroup) {
group.Middleware(func(r *ghttp.Request) {
i++
panic("error")
r.Middleware.Next()
}, func(r *ghttp.Request) {
i++
r.Middleware.Next()
})
group.ALL("/", func(r *ghttp.Request) {
r.Response.Write(i)
})
})
})
s.SetPort(p)
//s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "error")
})
}

View File

@ -0,0 +1,76 @@
// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp_test
import (
"fmt"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
"github.com/gogf/gf/test/gtest"
"testing"
"time"
)
func Test_Middleware_CORS(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
group.Middleware(MiddlewareCORS)
group.POST("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
// Common Checks.
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
// GET request does not any route.
resp, err := client.Get("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
resp.Close()
// POST request matches the route and CORS middleware.
resp, err = client.Post("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With")
gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE")
gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*")
gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800")
resp.Close()
})
// OPTIONS GET
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
client.SetHeader("Access-Control-Request-Method", "GET")
resp, err := client.Options("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
gtest.Assert(resp.ReadAllString(), "Not Found")
resp.Close()
})
// OPTIONS POST
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
client.SetHeader("Access-Control-Request-Method", "POST")
resp, err := client.Options("/api.v2/user/list")
gtest.Assert(err, nil)
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
resp.Close()
})
}

View File

@ -50,7 +50,6 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) {
pattern1 := "/:name/info"
s.BindHookHandlerByMap(pattern1, map[string]ghttp.HandlerFunc{
ghttp.HOOK_BEFORE_SERVE: func(r *ghttp.Request) {
fmt.Println("called")
r.SetParam("uid", i)
i++
},
@ -59,17 +58,17 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) {
r.Response.Write(r.Get("uid"))
})
//pattern2 := "/{object}/list/{page}.java"
//s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{
// ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) {
// r.Response.SetBuffer([]byte(
// fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i),
// ))
// },
//})
//s.BindHandler(pattern2, func(r *ghttp.Request) {
// r.Response.Write(r.Router.Uri)
//})
pattern2 := "/{object}/list/{page}.java"
s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{
ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) {
r.Response.SetBuffer([]byte(
fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i),
))
},
})
s.BindHandler(pattern2, func(r *ghttp.Request) {
r.Response.Write(r.Router.Uri)
})
s.SetPort(p)
//s.SetDumpRouterMap(false)
s.Start()