diff --git a/internal/middlewares/asset_override.go b/internal/middlewares/asset_override.go index 887d60117..e25444aba 100644 --- a/internal/middlewares/asset_override.go +++ b/internal/middlewares/asset_override.go @@ -7,8 +7,8 @@ import ( "github.com/valyala/fasthttp" ) -// AssetOverrideMiddleware allows overriding and serving of specific embedded assets from disk. -func AssetOverrideMiddleware(root string, strip int, next fasthttp.RequestHandler) fasthttp.RequestHandler { +// AssetOverride allows overriding and serving of specific embedded assets from disk. +func AssetOverride(root string, strip int, next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { if root == "" { next(ctx) diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index b0475281d..314d0afcd 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -30,6 +30,10 @@ var ( headerAccessControlMaxAge = []byte(fasthttp.HeaderAccessControlMaxAge) headerAccessControlRequestHeaders = []byte(fasthttp.HeaderAccessControlRequestHeaders) headerAccessControlRequestMethod = []byte(fasthttp.HeaderAccessControlRequestMethod) + + headerXContentTypeOptions = []byte(fasthttp.HeaderXContentTypeOptions) + headerReferrerPolicy = []byte(fasthttp.HeaderReferrerPolicy) + headerPermissionsPolicy = []byte("Permissions-Policy") ) var ( @@ -40,6 +44,10 @@ var ( headerValueVaryWildcard = []byte("Accept-Encoding") headerValueOriginWildcard = []byte("*") headerValueZero = []byte("0") + + headerValueNoSniff = []byte("nosniff") + headerValueStrictOriginCrossOrigin = []byte("strict-origin-when-cross-origin") + headerValueCohort = []byte("interest-cohort=()") ) var ( diff --git a/internal/middlewares/headers.go b/internal/middlewares/headers.go new file mode 100644 index 000000000..abb6e4b59 --- /dev/null +++ b/internal/middlewares/headers.go @@ -0,0 +1,16 @@ +package middlewares + +import ( + "github.com/valyala/fasthttp" +) + +// SecurityHeaders middleware adds several modern recommended security headers with safe values. +func SecurityHeaders(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + ctx.Response.Header.SetBytesKV(headerXContentTypeOptions, headerValueNoSniff) + ctx.Response.Header.SetBytesKV(headerReferrerPolicy, headerValueStrictOriginCrossOrigin) + ctx.Response.Header.SetBytesKV(headerPermissionsPolicy, headerValueCohort) + + next(ctx) + } +} diff --git a/internal/middlewares/log_request.go b/internal/middlewares/log_request.go index 7a1e9cd83..1790ba7dc 100644 --- a/internal/middlewares/log_request.go +++ b/internal/middlewares/log_request.go @@ -4,8 +4,8 @@ import ( "github.com/valyala/fasthttp" ) -// LogRequestMiddleware logs the query that is being treated. -func LogRequestMiddleware(next fasthttp.RequestHandler) fasthttp.RequestHandler { +// LogRequest logs the query that is being treated. +func LogRequest(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { autheliaCtx := &AutheliaCtx{RequestCtx: ctx} logger := NewRequestLogger(autheliaCtx) diff --git a/internal/middlewares/log_request_test.go b/internal/middlewares/log_request_test.go index ee32b986b..454430e28 100644 --- a/internal/middlewares/log_request_test.go +++ b/internal/middlewares/log_request_test.go @@ -13,7 +13,7 @@ func TestShouldCallNextFunction(t *testing.T) { f := func(ctx *fasthttp.RequestCtx) { val = true } context := &fasthttp.RequestCtx{} - LogRequestMiddleware(f)(context) + LogRequest(f)(context) assert.Equal(t, true, val) } diff --git a/internal/middlewares/strip_path.go b/internal/middlewares/strip_path.go index 8079bb420..5f581696d 100644 --- a/internal/middlewares/strip_path.go +++ b/internal/middlewares/strip_path.go @@ -6,8 +6,8 @@ import ( "github.com/valyala/fasthttp" ) -// StripPathMiddleware strips the first level of a path. -func StripPathMiddleware(path string, next fasthttp.RequestHandler) fasthttp.RequestHandler { +// StripPath strips the first level of a path. +func StripPath(path string, next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { uri := ctx.RequestURI() diff --git a/internal/server/handlers.go b/internal/server/handlers.go index b7f726823..6f66d5847 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -116,13 +116,13 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa r.GET("/"+f, handlerPublicHTML) } - r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(config.Server.AssetPath, 0, handlerPublicHTML)) - r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(config.Server.AssetPath, 2, handlerPublicHTML)) + r.GET("/favicon.ico", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerPublicHTML)) + r.GET("/static/media/logo.png", middlewares.AssetOverride(config.Server.AssetPath, 2, handlerPublicHTML)) r.GET("/static/{filepath:*}", handlerPublicHTML) // Locales. - r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(config.Server.AssetPath, 0, handlerLocales)) - r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(config.Server.AssetPath, 0, handlerLocales)) + r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales)) + r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales)) // Swagger. r.GET("/api/", middleware(serveSwaggerHandler)) @@ -298,9 +298,9 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa r.HandleMethodNotAllowed = true r.MethodNotAllowed = handlerMethodNotAllowed - handler := middlewares.LogRequestMiddleware(r.Handler) + handler := middlewares.LogRequest(middlewares.SecurityHeaders(r.Handler)) if config.Server.Path != "" { - handler = middlewares.StripPathMiddleware(config.Server.Path, handler) + handler = middlewares.StripPath(config.Server.Path, handler) } return handler