diff --git a/internal/handlers/handler_configuration_password_policy.go b/internal/handlers/handler_configuration_password_policy.go index 9e827ef6d..3423eec54 100644 --- a/internal/handlers/handler_configuration_password_policy.go +++ b/internal/handlers/handler_configuration_password_policy.go @@ -4,8 +4,8 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -// PasswordPolicyConfigurationGet get the password policy configuration. -func PasswordPolicyConfigurationGet(ctx *middlewares.AutheliaCtx) { +// PasswordPolicyConfigurationGET get the password policy configuration. +func PasswordPolicyConfigurationGET(ctx *middlewares.AutheliaCtx) { policyResponse := PassworPolicyBody{ Mode: "disabled", } diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index e1b7d9a9d..28edbd3fd 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -29,36 +29,15 @@ func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry { } // NewAutheliaCtx instantiate an AutheliaCtx out of a RequestCtx. -func NewAutheliaCtx(ctx *fasthttp.RequestCtx, configuration schema.Configuration, providers Providers) (*AutheliaCtx, error) { - autheliaCtx := new(AutheliaCtx) - autheliaCtx.RequestCtx = ctx - autheliaCtx.Providers = providers - autheliaCtx.Configuration = configuration - autheliaCtx.Logger = NewRequestLogger(autheliaCtx) - autheliaCtx.Clock = utils.RealClock{} +func NewAutheliaCtx(requestCTX *fasthttp.RequestCtx, configuration schema.Configuration, providers Providers) (ctx *AutheliaCtx) { + ctx = new(AutheliaCtx) + ctx.RequestCtx = requestCTX + ctx.Providers = providers + ctx.Configuration = configuration + ctx.Logger = NewRequestLogger(ctx) + ctx.Clock = utils.RealClock{} - return autheliaCtx, nil -} - -// AutheliaMiddleware is wrapping the RequestCtx into an AutheliaCtx providing Authelia related objects. -func AutheliaMiddleware(configuration schema.Configuration, providers Providers, middlewares ...StandardMiddleware) RequestHandlerBridge { - return func(next RequestHandler) fasthttp.RequestHandler { - bridge := func(ctx *fasthttp.RequestCtx) { - autheliaCtx, err := NewAutheliaCtx(ctx, configuration, providers) - if err != nil { - autheliaCtx.Error(err, messageOperationFailed) - return - } - - next(autheliaCtx) - } - - for i := len(middlewares) - 1; i >= 0; i-- { - bridge = middlewares[i](bridge) - } - - return bridge - } + return ctx } // AvailableSecondFactorMethods returns the available 2FA methods. diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go index e120b8c29..b15b5eb77 100644 --- a/internal/middlewares/authelia_context_test.go +++ b/internal/middlewares/authelia_context_test.go @@ -27,7 +27,9 @@ func TestShouldCallNextWithAutheliaCtx(t *testing.T) { } nextCalled := false - middlewares.AutheliaMiddleware(configuration, providers)(func(actx *middlewares.AutheliaCtx) { + middleware := middlewares.NewBridgeBuilder(configuration, providers).Build() + + middleware(func(actx *middlewares.AutheliaCtx) { // Authelia context wraps the request. assert.Equal(t, ctx, actx.RequestCtx) nextCalled = true diff --git a/internal/middlewares/bridge.go b/internal/middlewares/bridge.go new file mode 100644 index 000000000..5b4390566 --- /dev/null +++ b/internal/middlewares/bridge.go @@ -0,0 +1,63 @@ +package middlewares + +import ( + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/configuration/schema" +) + +// NewBridgeBuilder creates a new BridgeBuilder. +func NewBridgeBuilder(config schema.Configuration, providers Providers) *BridgeBuilder { + return &BridgeBuilder{ + config: config, + providers: providers, + } +} + +// WithConfig sets the schema.Configuration used with this BridgeBuilder. +func (b *BridgeBuilder) WithConfig(config schema.Configuration) *BridgeBuilder { + b.config = config + + return b +} + +// WithProviders sets the Providers used with this BridgeBuilder. +func (b *BridgeBuilder) WithProviders(providers Providers) *BridgeBuilder { + b.providers = providers + + return b +} + +// WithPreMiddlewares sets the Middleware's used with this BridgeBuilder which are applied before the actual Bridge. +func (b *BridgeBuilder) WithPreMiddlewares(middlewares ...Middleware) *BridgeBuilder { + b.preMiddlewares = middlewares + + return b +} + +// WithPostMiddlewares sets the AutheliaMiddleware's used with this BridgeBuilder which are applied after the actual +// Bridge. +func (b *BridgeBuilder) WithPostMiddlewares(middlewares ...AutheliaMiddleware) *BridgeBuilder { + b.postMiddlewares = middlewares + + return b +} + +// Build and return the Bridge configured by this BridgeBuilder. +func (b *BridgeBuilder) Build() Bridge { + return func(next RequestHandler) fasthttp.RequestHandler { + for i := len(b.postMiddlewares) - 1; i >= 0; i-- { + next = b.postMiddlewares[i](next) + } + + bridge := func(requestCtx *fasthttp.RequestCtx) { + next(NewAutheliaCtx(requestCtx, b.config, b.providers)) + } + + for i := len(b.preMiddlewares) - 1; i >= 0; i-- { + bridge = b.preMiddlewares[i](bridge) + } + + return bridge + } +} diff --git a/internal/middlewares/cors_test.go b/internal/middlewares/cors_test.go index 9e42c103d..0273eb69d 100644 --- a/internal/middlewares/cors_test.go +++ b/internal/middlewares/cors_test.go @@ -562,13 +562,13 @@ func Test_CORSMiddleware_AsMiddleware(t *testing.T) { ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET") - autheliaMiddleware := AutheliaMiddleware(schema.Configuration{}, Providers{}) + middleware := NewBridgeBuilder(schema.Configuration{}, Providers{}).Build() cors := NewCORSPolicyBuilder().WithAllowedMethods("GET", "OPTIONS") policy := cors.Build() - route := policy.Middleware(autheliaMiddleware(testNilHandler)) + route := policy.Middleware(middleware(testNilHandler)) route(ctx) diff --git a/internal/middlewares/strip_path.go b/internal/middlewares/strip_path.go index c3934a806..7ebfc3549 100644 --- a/internal/middlewares/strip_path.go +++ b/internal/middlewares/strip_path.go @@ -7,7 +7,7 @@ import ( ) // StripPath strips the first level of a path. -func StripPath(path string) (middleware StandardMiddleware) { +func StripPath(path string) (middleware Middleware) { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { uri := ctx.RequestURI() diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 4e4d2bf88..f485e0c3f 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -45,14 +45,23 @@ type Providers struct { // RequestHandler represents an Authelia request handler. type RequestHandler = func(*AutheliaCtx) -// Middleware represent an Authelia middleware. -type Middleware = func(RequestHandler) RequestHandler +// AutheliaMiddleware represent an Authelia middleware. +type AutheliaMiddleware = func(next RequestHandler) RequestHandler -// StandardMiddleware represents a fasthttp middleware. -type StandardMiddleware = func(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) +// Middleware represents a fasthttp middleware. +type Middleware = func(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) -// RequestHandlerBridge bridge a AutheliaCtx handle to a RequestHandler handler. -type RequestHandlerBridge = func(RequestHandler) fasthttp.RequestHandler +// Bridge represents the func signature that returns a fasthttp.RequestHandler given a RequestHandler allowing it to +// bridge between the two handlers. +type Bridge = func(RequestHandler) fasthttp.RequestHandler + +// BridgeBuilder is used to build a Bridge. +type BridgeBuilder struct { + config schema.Configuration + providers Providers + preMiddlewares []Middleware + postMiddlewares []AutheliaMiddleware +} // IdentityVerificationStartArgs represent the arguments used to customize the starting phase // of the identity verification process. diff --git a/internal/mocks/authelia_ctx.go b/internal/mocks/authelia_ctx.go index 535e2491e..735f86855 100644 --- a/internal/mocks/authelia_ctx.go +++ b/internal/mocks/authelia_ctx.go @@ -119,8 +119,8 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx { // Set a cookie to identify this client throughout the test. // request.Request.Header.SetCookie("authelia_session", "client_cookie"). - autheliaCtx, _ := middlewares.NewAutheliaCtx(request, configuration, providers) - mockAuthelia.Ctx = autheliaCtx + ctx := middlewares.NewAutheliaCtx(request, configuration, providers) + mockAuthelia.Ctx = ctx logger, hook := test.NewNullLogger() mockAuthelia.Hook = hook diff --git a/internal/server/handlers.go b/internal/server/handlers.go index acc333280..83fb2bf01 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -100,7 +100,8 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa handlerPublicHTML := newPublicHTMLEmbeddedHandler() handlerLocales := newLocalesEmbeddedHandler() - middleware := middlewares.AutheliaMiddleware(config, providers, middlewares.SecurityHeaders) + middleware := middlewares.NewBridgeBuilder(config, providers). + WithPreMiddlewares(middlewares.SecurityHeaders).Build() policyCORSPublicGET := middlewares.NewCORSPolicyBuilder(). WithAllowedMethods("OPTIONS", "GET"). @@ -134,17 +135,21 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa r.GET("/api/"+file, handlerPublicHTML) } - middlewareAPI := middlewares.AutheliaMiddleware( - config, providers, - middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone, - ) + middlewareAPI := middlewares.NewBridgeBuilder(config, providers). + WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). + Build() + + middleware1FA := middlewares.NewBridgeBuilder(config, providers). + WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). + WithPostMiddlewares(middlewares.Require1FA). + Build() r.GET("/api/health", middlewareAPI(handlers.HealthGET)) r.GET("/api/state", middlewareAPI(handlers.StateGET)) - r.GET("/api/configuration", middlewareAPI(middlewares.Require1FA(handlers.ConfigurationGET))) + r.GET("/api/configuration", middleware1FA(handlers.ConfigurationGET)) - r.GET("/api/configuration/password-policy", middlewareAPI(handlers.PasswordPolicyConfigurationGet)) + r.GET("/api/configuration/password-policy", middlewareAPI(handlers.PasswordPolicyConfigurationGET)) r.GET("/api/verify", middlewareAPI(handlers.VerifyGET(config.AuthenticationBackend))) r.HEAD("/api/verify", middlewareAPI(handlers.VerifyGET(config.AuthenticationBackend))) @@ -166,26 +171,26 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa } // Information about the user. - r.GET("/api/user/info", middlewareAPI(middlewares.Require1FA(handlers.UserInfoGET))) - r.POST("/api/user/info", middlewareAPI(middlewares.Require1FA(handlers.UserInfoPOST))) - r.POST("/api/user/info/2fa_method", middlewareAPI(middlewares.Require1FA(handlers.MethodPreferencePOST))) + r.GET("/api/user/info", middleware1FA(handlers.UserInfoGET)) + r.POST("/api/user/info", middleware1FA(handlers.UserInfoPOST)) + r.POST("/api/user/info/2fa_method", middleware1FA(handlers.MethodPreferencePOST)) if !config.TOTP.Disable { // TOTP related endpoints. - r.GET("/api/user/info/totp", middlewareAPI(middlewares.Require1FA(handlers.UserTOTPInfoGET))) - r.POST("/api/secondfactor/totp/identity/start", middlewareAPI(middlewares.Require1FA(handlers.TOTPIdentityStart))) - r.POST("/api/secondfactor/totp/identity/finish", middlewareAPI(middlewares.Require1FA(handlers.TOTPIdentityFinish))) - r.POST("/api/secondfactor/totp", middlewareAPI(middlewares.Require1FA(handlers.TimeBasedOneTimePasswordPOST))) + r.GET("/api/user/info/totp", middleware1FA(handlers.UserTOTPInfoGET)) + r.POST("/api/secondfactor/totp/identity/start", middleware1FA(handlers.TOTPIdentityStart)) + r.POST("/api/secondfactor/totp/identity/finish", middleware1FA(handlers.TOTPIdentityFinish)) + r.POST("/api/secondfactor/totp", middleware1FA(handlers.TimeBasedOneTimePasswordPOST)) } if !config.Webauthn.Disable { // Webauthn Endpoints. - r.POST("/api/secondfactor/webauthn/identity/start", middlewareAPI(middlewares.Require1FA(handlers.WebauthnIdentityStart))) - r.POST("/api/secondfactor/webauthn/identity/finish", middlewareAPI(middlewares.Require1FA(handlers.WebauthnIdentityFinish))) - r.POST("/api/secondfactor/webauthn/attestation", middlewareAPI(middlewares.Require1FA(handlers.WebauthnAttestationPOST))) + r.POST("/api/secondfactor/webauthn/identity/start", middleware1FA(handlers.WebauthnIdentityStart)) + r.POST("/api/secondfactor/webauthn/identity/finish", middleware1FA(handlers.WebauthnIdentityFinish)) + r.POST("/api/secondfactor/webauthn/attestation", middleware1FA(handlers.WebauthnAttestationPOST)) - r.GET("/api/secondfactor/webauthn/assertion", middlewareAPI(middlewares.Require1FA(handlers.WebauthnAssertionGET))) - r.POST("/api/secondfactor/webauthn/assertion", middlewareAPI(middlewares.Require1FA(handlers.WebauthnAssertionPOST))) + r.GET("/api/secondfactor/webauthn/assertion", middleware1FA(handlers.WebauthnAssertionGET)) + r.POST("/api/secondfactor/webauthn/assertion", middleware1FA(handlers.WebauthnAssertionPOST)) } // Configure DUO api endpoint only if configuration exists. @@ -203,9 +208,9 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa config.DuoAPI.Hostname, "")) } - r.GET("/api/secondfactor/duo_devices", middlewareAPI(middlewares.Require1FA(handlers.DuoDevicesGET(duoAPI)))) - r.POST("/api/secondfactor/duo", middlewareAPI(middlewares.Require1FA(handlers.DuoPOST(duoAPI)))) - r.POST("/api/secondfactor/duo_device", middlewareAPI(middlewares.Require1FA(handlers.DuoDevicePOST))) + r.GET("/api/secondfactor/duo_devices", middleware1FA(handlers.DuoDevicesGET(duoAPI))) + r.POST("/api/secondfactor/duo", middleware1FA(handlers.DuoPOST(duoAPI))) + r.POST("/api/secondfactor/duo_device", middleware1FA(handlers.DuoDevicePOST)) } if config.Server.EnablePprof { @@ -217,23 +222,27 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa } if providers.OpenIDConnect.Fosite != nil { - r.GET("/api/oidc/consent", middlewareAPI(handlers.OpenIDConnectConsentGET)) - r.POST("/api/oidc/consent", middlewareAPI(handlers.OpenIDConnectConsentPOST)) + middlewareOIDC := middlewares.NewBridgeBuilder(config, providers).WithPreMiddlewares( + middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNone, middlewares.SecurityHeadersNoStore, + ).Build() + + r.GET("/api/oidc/consent", middlewareOIDC(handlers.OpenIDConnectConsentGET)) + r.POST("/api/oidc/consent", middlewareOIDC(handlers.OpenIDConnectConsentPOST)) allowedOrigins := utils.StringSliceFromURLs(config.IdentityProviders.OIDC.CORS.AllowedOrigins) r.OPTIONS(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.HandleOPTIONS) - r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(middlewareAPI(handlers.OpenIDConnectConfigurationWellKnownGET))) + r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(middlewareOIDC(handlers.OpenIDConnectConfigurationWellKnownGET))) r.OPTIONS(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.HandleOPTIONS) - r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(middlewareAPI(handlers.OAuthAuthorizationServerWellKnownGET))) + r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(middlewareOIDC(handlers.OAuthAuthorizationServerWellKnownGET))) r.OPTIONS(oidc.JWKsPath, policyCORSPublicGET.HandleOPTIONS) r.GET(oidc.JWKsPath, policyCORSPublicGET.Middleware(middlewareAPI(handlers.JSONWebKeySetGET))) // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS) - r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(middlewareAPI(handlers.JSONWebKeySetGET))) + r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(middlewareOIDC(handlers.JSONWebKeySetGET))) policyCORSAuthorization := middlewares.NewCORSPolicyBuilder(). WithAllowedMethods("OPTIONS", "GET"). @@ -242,11 +251,11 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa Build() r.OPTIONS(oidc.AuthorizationPath, policyCORSAuthorization.HandleOnlyOPTIONS) - r.GET(oidc.AuthorizationPath, middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET))) + r.GET(oidc.AuthorizationPath, middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET))) // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint. r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS) - r.GET("/api/oidc/authorize", middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET))) + r.GET("/api/oidc/authorize", middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET))) policyCORSToken := middlewares.NewCORSPolicyBuilder(). WithAllowCredentials(true). @@ -256,7 +265,7 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa Build() r.OPTIONS(oidc.TokenPath, policyCORSToken.HandleOPTIONS) - r.POST(oidc.TokenPath, policyCORSToken.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST)))) + r.POST(oidc.TokenPath, policyCORSToken.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST)))) policyCORSUserinfo := middlewares.NewCORSPolicyBuilder(). WithAllowCredentials(true). @@ -266,8 +275,8 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa Build() r.OPTIONS(oidc.UserinfoPath, policyCORSUserinfo.HandleOPTIONS) - r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))) - r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))) + r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))) + r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))) policyCORSIntrospection := middlewares.NewCORSPolicyBuilder(). WithAllowCredentials(true). @@ -277,11 +286,11 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa Build() r.OPTIONS(oidc.IntrospectionPath, policyCORSIntrospection.HandleOPTIONS) - r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))) + r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))) // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS) - r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))) + r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))) policyCORSRevocation := middlewares.NewCORSPolicyBuilder(). WithAllowCredentials(true). @@ -291,11 +300,11 @@ func getHandler(config schema.Configuration, providers middlewares.Providers) fa Build() r.OPTIONS(oidc.RevocationPath, policyCORSRevocation.HandleOPTIONS) - r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))) + r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))) // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS) - r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(middlewareAPI(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))) + r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(middlewareOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))) } r.NotFound = handlerNotFound(middleware(serveIndexHandler))