diff --git a/internal/handlers/oidc_register.go b/internal/handlers/oidc_register.go index 63af0362c..58646da7d 100644 --- a/internal/handlers/oidc_register.go +++ b/internal/handlers/oidc_register.go @@ -10,8 +10,8 @@ import ( // RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout. func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) { // TODO: Add OPTIONS handler. - router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(wellKnownOpenIDConnectConfigurationGET)) - router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(wellKnownOAuthAuthorizationServerGET)) + router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET))) + router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET))) router.GET(pathOpenIDConnectConsent, middleware(oidcConsent)) diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index dbe37e6f5..54b350a7e 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -15,6 +15,24 @@ var ( headerXOriginalURL = []byte("X-Original-URL") headerXForwardedMethod = []byte("X-Forwarded-Method") + headerVary = []byte(fasthttp.HeaderVary) + headerOrigin = []byte(fasthttp.HeaderOrigin) + headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials) + headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders) + headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods) + headerAccessControlAllowOrigin = []byte(fasthttp.HeaderAccessControlAllowOrigin) + headerAccessControlMaxAge = []byte(fasthttp.HeaderAccessControlMaxAge) + headerAccessControlRequestHeaders = []byte(fasthttp.HeaderAccessControlRequestHeaders) + headerAccessControlRequestMethod = []byte(fasthttp.HeaderAccessControlRequestMethod) +) + +var ( + headerValueFalse = []byte("false") + headerValueMaxAge = []byte("100") + headerValueVary = []byte("Accept-Encoding, Origin") +) + +var ( protoHTTPS = []byte("https") protoHTTP = []byte("http") diff --git a/internal/middlewares/cors.go b/internal/middlewares/cors.go new file mode 100644 index 000000000..a4152dcf1 --- /dev/null +++ b/internal/middlewares/cors.go @@ -0,0 +1,53 @@ +package middlewares + +import ( + "net/url" + "strings" + + "github.com/valyala/fasthttp" +) + +// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well +// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied +// to both Accept-Encoding and Origin. It grants the GET Request Method only. +func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler { + return func(ctx *AutheliaCtx) { + if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil { + corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin) + } + + next(ctx) + } +} + +func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) { + originURL, err := url.Parse(string(origin)) + if err != nil || originURL.Scheme != "https" { + return + } + + resp.Header.SetBytesKV(headerVary, headerValueVary) + resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin) + resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse) + resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge) + + if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil { + requestedHeaders := strings.Split(string(headers), ",") + allowHeaders := make([]string, len(requestedHeaders)) + + for i, header := range requestedHeaders { + headerTrimmed := strings.Trim(header, " ") + if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) { + allowHeaders[i] = headerTrimmed + } + } + + if len(allowHeaders) != 0 { + resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", "))) + } + } + + if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil { + resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods) + } +} diff --git a/internal/middlewares/cors_test.go b/internal/middlewares/cors_test.go new file mode 100644 index 000000000..f44106aee --- /dev/null +++ b/internal/middlewares/cors_test.go @@ -0,0 +1,65 @@ +package middlewares + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" +) + +func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) { + req := fasthttp.AcquireRequest() + resp := fasthttp.Response{} + + origin := []byte("https://myapp.example.com") + + req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + + corsApplyAutomaticAllowAllPolicy(req, &resp, origin) + + assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) { + req := fasthttp.AcquireRequest() + resp := fasthttp.Response{} + + origin := []byte("https://myapp.example.com") + + req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + req.Header.SetBytesK(headerAccessControlRequestMethod, "GET") + + corsApplyAutomaticAllowAllPolicy(req, &resp, origin) + + assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) { + req := fasthttp.AcquireRequest() + + resp := fasthttp.Response{} + + origin := []byte("http://myapp.example.com") + + req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + req.Header.SetBytesK(headerAccessControlRequestMethod, "GET") + + corsApplyAutomaticAllowAllPolicy(req, &resp, origin) + + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods)) +}