diff --git a/net/ghttp/ghttp_response_cors.go b/net/ghttp/ghttp_response_cors.go index 824ed5500..df2bc5f6f 100644 --- a/net/ghttp/ghttp_response_cors.go +++ b/net/ghttp/ghttp_response_cors.go @@ -94,28 +94,13 @@ func (r *Response) CORS(options CORSOptions) { r.Header().Set("Access-Control-Allow-Headers", options.AllowHeaders) } // No continue service handling if it's OPTIONS request. + // Note that there's special checks in previous router searching, + // so if it goes to here it means there's already serving handler exist. if gstr.Equal(r.Request.Method, "OPTIONS") { - // Request method handler searching. - // It here simply uses Server.routesMap attribute enhancing the searching performance. - if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" { - routerKey := "" - for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} { - for _, v := range []string{gDEFAULT_METHOD, method} { - routerKey = r.Server.routerMapKey("", v, r.Request.URL.Path, domain) - if r.Server.routesMap[routerKey] != nil { - if r.Status == 0 { - r.Status = http.StatusOK - } - // No continue serving. - r.Request.ExitAll() - } - } - } - } - // Cannot find the request serving handler, it then responses 404. if r.Status == 0 { - r.Status = http.StatusNotFound + r.Status = http.StatusOK } + // No continue serving. r.Request.ExitAll() } } diff --git a/net/ghttp/ghttp_unit_middleware_cors_test.go b/net/ghttp/ghttp_unit_middleware_cors_test.go index 4044d59e6..c5d24529e 100644 --- a/net/ghttp/ghttp_unit_middleware_cors_test.go +++ b/net/ghttp/ghttp_unit_middleware_cors_test.go @@ -15,7 +15,7 @@ import ( "time" ) -func Test_Middleware_CORS(t *testing.T) { +func Test_Middleware_CORS1(t *testing.T) { p := ports.PopRand() s := g.Server(p) s.Group("/api.v2", func(group *ghttp.RouterGroup) { @@ -77,3 +77,69 @@ func Test_Middleware_CORS(t *testing.T) { resp.Close() }) } + +func Test_Middleware_CORS2(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + s.Group("/api.v2", func(group *ghttp.RouterGroup) { + group.Middleware(MiddlewareCORS) + group.GET("/user/list/{type}", func(r *ghttp.Request) { + r.Response.Write(r.Get("type")) + }) + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + // Common Checks. + gtest.Assert(client.GetContent("/"), "Not Found") + gtest.Assert(client.GetContent("/api.v2"), "Not Found") + // Get request. + resp, err := client.Get("/api.v2/user/list/1") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1) + gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With") + gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE") + gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*") + gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800") + gtest.Assert(resp.ReadAllString(), "1") + resp.Close() + }) + // OPTIONS GET None. + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "GET") + resp, err := client.Options("/api.v2/user") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0) + gtest.Assert(resp.StatusCode, 404) + resp.Close() + }) + // OPTIONS GET + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "GET") + resp, err := client.Options("/api.v2/user/list/1") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1) + gtest.Assert(resp.StatusCode, 200) + resp.Close() + }) + // OPTIONS POST + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "POST") + resp, err := client.Options("/api.v2/user/list/1") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0) + gtest.Assert(resp.StatusCode, 404) + resp.Close() + }) +}