348 lines
10 KiB
Go
348 lines
10 KiB
Go
package middlewares
|
|
|
|
import (
|
|
"bytes"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
|
|
"github.com/authelia/authelia/v4/internal/utils"
|
|
)
|
|
|
|
// NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy which adds the Vary header
|
|
// with a value reflecting that the Origin header will Vary this response, then if the Origin header has a https scheme
|
|
// it makes the following additional adjustments: copies the Origin header to the Access-Control-Allow-Origin header
|
|
// effectively allowing all origins, sets the Access-Control-Allow-Credentials header to false which disallows CORS
|
|
// requests from sending cookies etc, sets the Access-Control-Allow-Headers header to the value specified by
|
|
// Access-Control-Request-Headers in the request excluding the Cookie/Authorization/Proxy-Authorization and special *
|
|
// values, sets Access-Control-Allow-Methods to the value specified by the Access-Control-Request-Method header, sets
|
|
// the Access-Control-Max-Age header to 100.
|
|
//
|
|
// These behaviours can be overridden by the With methods on the returned policy.
|
|
func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) {
|
|
return &CORSPolicyBuilder{
|
|
enabled: true,
|
|
maxAge: 100,
|
|
}
|
|
}
|
|
|
|
// CORSPolicyBuilder collects CORS configuration through its With methods and renders it into a CORSPolicy via Build.
// It aims to simplify CORS configurations by providing handlers and middleware with sensible defaults.
type CORSPolicyBuilder struct {
	// enabled toggles the whole policy, varyOnly restricts it to the Vary header, and varySet records that WithVary
	// was called so that an explicitly empty Vary list is honoured instead of the default.
	enabled, varyOnly, varySet bool

	// methods, headers and origins are the allow-lists configured through the With methods.
	methods, headers, origins []string

	// credentials is rendered as the Access-Control-Allow-Credentials value.
	credentials bool

	// vary lists the header names used for the Vary header when varySet is true.
	vary []string

	// maxAge is the Access-Control-Max-Age value; values <= 0 fall back to the package default at Build time.
	maxAge int
}
|
|
|
|
// Build reads the CORSPolicyBuilder configuration and generates a CORSPolicy.
|
|
func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) {
|
|
policy = &CORSPolicy{
|
|
enabled: b.enabled,
|
|
varyOnly: b.varyOnly,
|
|
credentials: []byte(strconv.FormatBool(b.credentials)),
|
|
origins: b.buildOrigins(),
|
|
headers: b.buildHeaders(),
|
|
vary: b.buildVary(),
|
|
}
|
|
|
|
if len(b.methods) != 0 {
|
|
policy.methods = []byte(strings.Join(b.methods, ", "))
|
|
}
|
|
|
|
if b.maxAge <= 0 {
|
|
policy.maxAge = headerValueMaxAge
|
|
} else {
|
|
policy.maxAge = []byte(strconv.Itoa(b.maxAge))
|
|
}
|
|
|
|
return policy
|
|
}
|
|
|
|
func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) {
|
|
if len(b.origins) != 0 {
|
|
if len(b.origins) == 1 && b.origins[0] == "*" {
|
|
origins = append(origins, []byte(b.origins[0]))
|
|
} else {
|
|
for _, origin := range b.origins {
|
|
origins = append(origins, []byte(origin))
|
|
}
|
|
}
|
|
}
|
|
|
|
return origins
|
|
}
|
|
|
|
func (b CORSPolicyBuilder) buildHeaders() (headers []byte) {
|
|
if len(b.headers) != 0 {
|
|
h := b.headers
|
|
|
|
if b.credentials {
|
|
if !utils.IsStringInSliceFold(fasthttp.HeaderCookie, h) {
|
|
h = append(h, fasthttp.HeaderCookie)
|
|
}
|
|
|
|
if !utils.IsStringInSliceFold(fasthttp.HeaderAuthorization, h) {
|
|
h = append(h, fasthttp.HeaderAuthorization)
|
|
}
|
|
|
|
if !utils.IsStringInSliceFold(fasthttp.HeaderProxyAuthorization, h) {
|
|
h = append(h, fasthttp.HeaderProxyAuthorization)
|
|
}
|
|
}
|
|
|
|
headers = utils.JoinAndCanonicalizeHeaders(headerSeparator, h...)
|
|
}
|
|
|
|
return headers
|
|
}
|
|
|
|
func (b CORSPolicyBuilder) buildVary() (vary []byte) {
|
|
if b.varySet {
|
|
if len(b.vary) != 0 {
|
|
vary = utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...)
|
|
}
|
|
} else {
|
|
if len(b.origins) == 1 && b.origins[0] == "*" {
|
|
vary = headerValueVaryWildcard
|
|
} else {
|
|
vary = headerValueVary
|
|
}
|
|
}
|
|
|
|
return vary
|
|
}
|
|
|
|
// WithEnabled changes the enabled state of the middleware. If the middleware is initialized with NewCORSPolicyBuilder this
|
|
// value will be true but this function can override the value. Setting it to false prevents the middleware from adding
|
|
// any CORS headers. The only effect this middleware has after disabling this is the HandleOPTIONS and HandleOnlyOPTIONS
|
|
// handlers still function to return a HTTP 204 No Content, with the Allow header communicating the available HTTP
|
|
// method verbs. The main benefit of this option is that you don't have to implement complex logic to add/remove the
|
|
// middleware, you can just add it with the Middleware method, and adjust it using the WithEnabled method.
|
|
func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) {
|
|
b.enabled = enabled
|
|
|
|
return b
|
|
}
|
|
|
|
// WithAllowedMethods takes a list or HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with
|
|
// that value.
|
|
func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) {
|
|
b.methods = methods
|
|
|
|
return b
|
|
}
|
|
|
|
// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the origin matches one of these.
|
|
func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) {
|
|
b.origins = origins
|
|
|
|
return b
|
|
}
|
|
|
|
// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header.
|
|
func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) {
|
|
b.headers = headers
|
|
|
|
return b
|
|
}
|
|
|
|
// WithAllowCredentials takes bool and alters the default Access-Control-Allow-Credentials header.
|
|
func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) {
|
|
b.credentials = allow
|
|
|
|
return b
|
|
}
|
|
|
|
// WithVary takes a list of header strings and alters the default Vary header.
|
|
func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) {
|
|
b.vary = headers
|
|
b.varySet = true
|
|
|
|
return b
|
|
}
|
|
|
|
// WithVaryOnly just adds the Vary header.
|
|
func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) {
|
|
b.varyOnly = varyOnly
|
|
|
|
return b
|
|
}
|
|
|
|
// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header.
|
|
func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) {
|
|
b.maxAge = age
|
|
|
|
return b
|
|
}
|
|
|
|
// CORSPolicy is a middleware that handles adding CORS headers. It is produced by CORSPolicyBuilder.Build and holds
// every header value pre-rendered as bytes.
type CORSPolicy struct {
	// enabled toggles the whole policy; varyOnly restricts it to the Vary header.
	enabled, varyOnly bool

	// methods and headers are the rendered Access-Control-Allow-Methods and Access-Control-Allow-Headers values.
	methods, headers []byte

	// origins holds the allowed origins; nil means only https origins are accepted and echoed.
	origins [][]byte

	// credentials, vary and maxAge are the rendered values of their respective headers.
	credentials, vary, maxAge []byte
}
|
|
|
|
// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 204
|
|
// without a body. This handler should generally not be used without using WithAllowedMethods.
|
|
func (p *CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) {
|
|
p.handleOPTIONS(ctx)
|
|
p.handle(ctx)
|
|
}
|
|
|
|
// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 204
|
|
// without a body. This handler should generally not be used without using WithAllowedMethods.
|
|
func (p *CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) {
|
|
p.handleOPTIONS(ctx)
|
|
}
|
|
|
|
// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicyBuilder.
|
|
func (p *CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
p.handle(ctx)
|
|
|
|
next(ctx)
|
|
}
|
|
}
|
|
|
|
func (p *CORSPolicy) handle(ctx *fasthttp.RequestCtx) {
|
|
if !p.enabled {
|
|
return
|
|
}
|
|
|
|
p.handleVary(ctx)
|
|
|
|
if !p.varyOnly {
|
|
p.handleCORS(ctx)
|
|
}
|
|
}
|
|
|
|
func (p *CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) {
|
|
ctx.Response.ResetBody()
|
|
|
|
/* The OPTIONS method should not return a 204 as per the following specifications when read together:
|
|
|
|
RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7):
|
|
A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in
|
|
the response.
|
|
|
|
RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2):
|
|
A server MUST NOT send a Content-Length header field in any response with a status code of 1xx (Informational)
|
|
or 204 (No Content).
|
|
*/
|
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
|
ctx.Response.Header.SetBytesKV(headerContentLength, headerValueZero)
|
|
|
|
if len(p.methods) != 0 {
|
|
ctx.Response.Header.SetBytesKV(headerAllow, p.methods)
|
|
}
|
|
}
|
|
|
|
func (p *CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) {
|
|
if len(p.vary) != 0 {
|
|
ctx.Response.Header.SetBytesKV(headerVary, p.vary)
|
|
}
|
|
}
|
|
|
|
func (p *CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
|
|
var (
|
|
originURL *url.URL
|
|
err error
|
|
)
|
|
|
|
origin := ctx.Request.Header.PeekBytes(headerOrigin)
|
|
|
|
// Skip processing of any `https` scheme URL that has not expressly been configured.
|
|
if originURL, err = url.ParseRequestURI(string(origin)); err != nil || (originURL.Scheme != strProtoHTTPS && p.origins == nil) {
|
|
return
|
|
}
|
|
|
|
var allowedOrigin []byte
|
|
|
|
switch len(p.origins) {
|
|
case 0:
|
|
allowedOrigin = origin
|
|
default:
|
|
for i := 0; i < len(p.origins); i++ {
|
|
if bytes.Equal(p.origins[i], headerValueOriginWildcard) {
|
|
allowedOrigin = headerValueOriginWildcard
|
|
} else if bytes.Equal(p.origins[i], origin) {
|
|
allowedOrigin = origin
|
|
}
|
|
}
|
|
|
|
if len(allowedOrigin) == 0 {
|
|
return
|
|
}
|
|
}
|
|
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin)
|
|
|
|
if len(p.credentials) != 0 {
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials)
|
|
}
|
|
|
|
if len(p.maxAge) != 0 {
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge)
|
|
}
|
|
|
|
p.handleAllowedHeaders(ctx)
|
|
p.handleAllowedMethods(ctx)
|
|
}
|
|
|
|
func (p *CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) {
|
|
switch len(p.methods) {
|
|
case 0:
|
|
// TODO: It may be beneficial to be able to control this automatic behaviour.
|
|
if requestMethods := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
|
|
}
|
|
default:
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods)
|
|
}
|
|
}
|
|
|
|
func (p *CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) {
|
|
switch len(p.headers) {
|
|
case 0:
|
|
// TODO: It may be beneficial to be able to control this automatic behaviour.
|
|
if headers := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
|
|
requestedHeaders := strings.Split(string(headers), ",")
|
|
allowHeaders := make([]string, 0, len(requestedHeaders))
|
|
|
|
for i := 0; i < len(requestedHeaders); i++ {
|
|
headerTrimmed := strings.Trim(requestedHeaders[i], " ")
|
|
|
|
if headerTrimmed == "*" {
|
|
continue
|
|
}
|
|
|
|
if bytes.Equal(p.credentials, headerValueTrue) ||
|
|
(!strings.EqualFold(fasthttp.HeaderCookie, headerTrimmed) &&
|
|
!strings.EqualFold(fasthttp.HeaderAuthorization, headerTrimmed) &&
|
|
!strings.EqualFold(fasthttp.HeaderProxyAuthorization, headerTrimmed)) {
|
|
allowHeaders = append(allowHeaders, headerTrimmed)
|
|
}
|
|
}
|
|
|
|
if len(allowHeaders) != 0 {
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
|
|
}
|
|
}
|
|
default:
|
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers)
|
|
}
|
|
}
|