mirror of
https://gitee.com/johng/gf.git
synced 2024-12-03 12:47:50 +08:00
improve CORS feature for ghttp.Server
This commit is contained in:
parent
f3d859159d
commit
33ae93e050
@ -94,28 +94,13 @@ func (r *Response) CORS(options CORSOptions) {
|
|||||||
r.Header().Set("Access-Control-Allow-Headers", options.AllowHeaders)
|
r.Header().Set("Access-Control-Allow-Headers", options.AllowHeaders)
|
||||||
}
|
}
|
||||||
// No continue service handling if it's OPTIONS request.
|
// No continue service handling if it's OPTIONS request.
|
||||||
|
// Note that there's special checks in previous router searching,
|
||||||
|
// so if it goes to here it means there's already serving handler exist.
|
||||||
if gstr.Equal(r.Request.Method, "OPTIONS") {
|
if gstr.Equal(r.Request.Method, "OPTIONS") {
|
||||||
// Request method handler searching.
|
|
||||||
// It here simply uses Server.routesMap attribute enhancing the searching performance.
|
|
||||||
if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" {
|
|
||||||
routerKey := ""
|
|
||||||
for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} {
|
|
||||||
for _, v := range []string{gDEFAULT_METHOD, method} {
|
|
||||||
routerKey = r.Server.routerMapKey("", v, r.Request.URL.Path, domain)
|
|
||||||
if r.Server.routesMap[routerKey] != nil {
|
|
||||||
if r.Status == 0 {
|
|
||||||
r.Status = http.StatusOK
|
|
||||||
}
|
|
||||||
// No continue serving.
|
|
||||||
r.Request.ExitAll()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Cannot find the request serving handler, it then responses 404.
|
|
||||||
if r.Status == 0 {
|
if r.Status == 0 {
|
||||||
r.Status = http.StatusNotFound
|
r.Status = http.StatusOK
|
||||||
}
|
}
|
||||||
|
// No continue serving.
|
||||||
r.Request.ExitAll()
|
r.Request.ExitAll()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Middleware_CORS(t *testing.T) {
|
func Test_Middleware_CORS1(t *testing.T) {
|
||||||
p := ports.PopRand()
|
p := ports.PopRand()
|
||||||
s := g.Server(p)
|
s := g.Server(p)
|
||||||
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
|
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
|
||||||
@ -77,3 +77,69 @@ func Test_Middleware_CORS(t *testing.T) {
|
|||||||
resp.Close()
|
resp.Close()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_Middleware_CORS2(t *testing.T) {
|
||||||
|
p := ports.PopRand()
|
||||||
|
s := g.Server(p)
|
||||||
|
s.Group("/api.v2", func(group *ghttp.RouterGroup) {
|
||||||
|
group.Middleware(MiddlewareCORS)
|
||||||
|
group.GET("/user/list/{type}", func(r *ghttp.Request) {
|
||||||
|
r.Response.Write(r.Get("type"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
s.SetPort(p)
|
||||||
|
s.SetDumpRouterMap(false)
|
||||||
|
s.Start()
|
||||||
|
defer s.Shutdown()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
gtest.Case(t, func() {
|
||||||
|
client := ghttp.NewClient()
|
||||||
|
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||||
|
// Common Checks.
|
||||||
|
gtest.Assert(client.GetContent("/"), "Not Found")
|
||||||
|
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
|
||||||
|
// Get request.
|
||||||
|
resp, err := client.Get("/api.v2/user/list/1")
|
||||||
|
gtest.Assert(err, nil)
|
||||||
|
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
|
||||||
|
gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With")
|
||||||
|
gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE")
|
||||||
|
gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*")
|
||||||
|
gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800")
|
||||||
|
gtest.Assert(resp.ReadAllString(), "1")
|
||||||
|
resp.Close()
|
||||||
|
})
|
||||||
|
// OPTIONS GET None.
|
||||||
|
gtest.Case(t, func() {
|
||||||
|
client := ghttp.NewClient()
|
||||||
|
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||||
|
client.SetHeader("Access-Control-Request-Method", "GET")
|
||||||
|
resp, err := client.Options("/api.v2/user")
|
||||||
|
gtest.Assert(err, nil)
|
||||||
|
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
|
||||||
|
gtest.Assert(resp.StatusCode, 404)
|
||||||
|
resp.Close()
|
||||||
|
})
|
||||||
|
// OPTIONS GET
|
||||||
|
gtest.Case(t, func() {
|
||||||
|
client := ghttp.NewClient()
|
||||||
|
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||||
|
client.SetHeader("Access-Control-Request-Method", "GET")
|
||||||
|
resp, err := client.Options("/api.v2/user/list/1")
|
||||||
|
gtest.Assert(err, nil)
|
||||||
|
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1)
|
||||||
|
gtest.Assert(resp.StatusCode, 200)
|
||||||
|
resp.Close()
|
||||||
|
})
|
||||||
|
// OPTIONS POST
|
||||||
|
gtest.Case(t, func() {
|
||||||
|
client := ghttp.NewClient()
|
||||||
|
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||||
|
client.SetHeader("Access-Control-Request-Method", "POST")
|
||||||
|
resp, err := client.Options("/api.v2/user/list/1")
|
||||||
|
gtest.Assert(err, nil)
|
||||||
|
gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0)
|
||||||
|
gtest.Assert(resp.StatusCode, 404)
|
||||||
|
resp.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user