feat(oidc): add automatic allow all cors to discovery (#2953)

This adds a Cross Origin Resource Sharing policy that automatically allows any cross-origin request to the OpenID Connect discovery documents.
pull/2951/head^2
James Elliott 2022-03-04 15:46:12 +11:00 committed by GitHub
parent a5c400cb1d
commit a8f5a70b03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 138 additions and 2 deletions

View File

@ -10,8 +10,8 @@ import (
// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout.
func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) {
// TODO: Add OPTIONS handler.
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(wellKnownOpenIDConnectConfigurationGET))
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(wellKnownOAuthAuthorizationServerGET))
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET)))
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET)))
router.GET(pathOpenIDConnectConsent, middleware(oidcConsent))

View File

@ -15,6 +15,24 @@ var (
headerXOriginalURL = []byte("X-Original-URL")
headerXForwardedMethod = []byte("X-Forwarded-Method")
headerVary = []byte(fasthttp.HeaderVary)
headerOrigin = []byte(fasthttp.HeaderOrigin)
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
headerAccessControlAllowOrigin = []byte(fasthttp.HeaderAccessControlAllowOrigin)
headerAccessControlMaxAge = []byte(fasthttp.HeaderAccessControlMaxAge)
headerAccessControlRequestHeaders = []byte(fasthttp.HeaderAccessControlRequestHeaders)
headerAccessControlRequestMethod = []byte(fasthttp.HeaderAccessControlRequestMethod)
)
var (
headerValueFalse = []byte("false")
headerValueMaxAge = []byte("100")
headerValueVary = []byte("Accept-Encoding, Origin")
)
var (
protoHTTPS = []byte("https")
protoHTTP = []byte("http")

View File

@ -0,0 +1,53 @@
package middlewares
import (
"net/url"
"strings"
"github.com/valyala/fasthttp"
)
// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well
// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied
// to both Accept-Encoding and Origin. It grants the GET Request Method only.
func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil {
corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin)
}
next(ctx)
}
}
func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) {
originURL, err := url.Parse(string(origin))
if err != nil || originURL.Scheme != "https" {
return
}
resp.Header.SetBytesKV(headerVary, headerValueVary)
resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin)
resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse)
resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge)
if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
requestedHeaders := strings.Split(string(headers), ",")
allowHeaders := make([]string, len(requestedHeaders))
for i, header := range requestedHeaders {
headerTrimmed := strings.Trim(header, " ")
if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) {
allowHeaders[i] = headerTrimmed
}
}
if len(allowHeaders) != 0 {
resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
}
}
if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
}
}

View File

@ -0,0 +1,65 @@
package middlewares
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
)
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
req := fasthttp.AcquireRequest()
resp := fasthttp.Response{}
origin := []byte("https://myapp.example.com")
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
}
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
req := fasthttp.AcquireRequest()
resp := fasthttp.Response{}
origin := []byte("https://myapp.example.com")
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods))
}
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
req := fasthttp.AcquireRequest()
resp := fasthttp.Response{}
origin := []byte("http://myapp.example.com")
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
}