From 4ebd8fdf4e9fb0eb20684197f39929304fcb74b7 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Thu, 7 Apr 2022 10:58:51 +1000 Subject: [PATCH] feat(oidc): provide cors config including options handlers (#3005) This adjusts the CORS headers appropriately for OpenID Connect. This includes responding to OPTIONS requests appropriately. Currently this is only configured to operate when the Origin scheme is HTTPS; but can easily be expanded in the future to include additional Origins. --- config.template.yml | 20 + docs/configuration/identity-providers/oidc.md | 136 +++- internal/configuration/config.template.yml | 20 + internal/configuration/provider_test.go | 18 + .../schema/identity_providers.go | 15 +- .../test_resources/config_oidc.yml | 133 ++++ internal/configuration/validator/const.go | 12 +- .../validator/identity_providers.go | 60 +- .../validator/identity_providers_test.go | 89 ++- internal/handlers/const.go | 10 - .../{handler_oidc_jwks.go => handler_jwks.go} | 3 +- ...tion.go => handler_oauth_introspection.go} | 5 +- ...ocation.go => handler_oauth_revocation.go} | 5 +- .../handlers/handler_oidc_authorization.go | 5 +- internal/handlers/handler_oidc_consent.go | 6 +- internal/handlers/handler_oidc_token.go | 5 +- internal/handlers/handler_oidc_userinfo.go | 7 +- internal/handlers/handler_oidc_wellknown.go | 16 +- internal/handlers/oidc_register.go | 37 -- internal/middlewares/const.go | 22 +- internal/middlewares/cors.go | 344 +++++++++- internal/middlewares/cors_test.go | 594 +++++++++++++++++- internal/oidc/const.go | 23 +- internal/oidc/provider_test.go | 4 +- internal/server/const.go | 13 +- internal/server/error_handler.go | 28 - internal/server/handler_notfound.go | 25 - internal/server/handlers.go | 56 ++ internal/server/options_handler.go | 11 - internal/server/server.go | 148 ++++- internal/server/template.go | 6 +- internal/utils/strings.go | 61 ++ internal/utils/strings_test.go | 46 ++ 33 files changed, 1729 insertions(+), 254 deletions(-) create mode 100644 internal/configuration/test_resources/config_oidc.yml rename internal/handlers/{handler_oidc_jwks.go => handler_jwks.go} (67%) rename internal/handlers/{handler_oidc_introspection.go => handler_oauth_introspection.go} (80%) rename internal/handlers/{handler_oidc_revocation.go => handler_oauth_revocation.go} (64%) delete mode 100644 internal/handlers/oidc_register.go delete mode 100644 internal/server/error_handler.go delete mode 100644 internal/server/handler_notfound.go create mode 100644 internal/server/handlers.go delete mode 100644 internal/server/options_handler.go diff --git a/config.template.yml b/config.template.yml index db21d7af5..8ea8ae488 100644 --- a/config.template.yml +++ b/config.template.yml @@ -767,6 +767,26 @@ notifier: ## for security reasons. # enforce_pkce: public_clients_only + ## Cross-Origin Resource Sharing (CORS) settings. + # cors: + ## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on. + # endpoints: + # - authorization + # - token + # - revocation + # - introspection + # - userinfo + + ## List of allowed origins. + ## Any origin with https is permitted unless this option is configured or the + ## allowed_origins_from_client_redirect_uris option is enabled. + # allowed_origins: + # - https://example.com + + ## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, + ## provided they have the scheme http or https and do not have the hostname of localhost. + # allowed_origins_from_client_redirect_uris: false + ## Clients is a list of known clients and their configuration. # clients: # - diff --git a/docs/configuration/identity-providers/oidc.md b/docs/configuration/identity-providers/oidc.md index 1c1dc0278..f2ffd2766 100644 --- a/docs/configuration/identity-providers/oidc.md +++ b/docs/configuration/identity-providers/oidc.md @@ -35,6 +35,15 @@ identity_providers: refresh_token_lifespan: 90m enable_client_debug_messages: false enforce_pkce: public_clients_only + cors: + endpoints: + - authorization + - token + - revocation + - introspection + allowed_origins: + - https://example.com + allowed_origins_from_client_redirect_uris: false clients: - id: myapp description: My Application @@ -218,6 +227,79 @@ Allows PKCE `plain` challenges when set to `true`. ***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead. +### cors + +Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows +you to configure the optional parts. We reply with CORS headers when the request includes the Origin header. + +##### endpoints +
+type: list(string) +{: .label .label-config .label-purple } +default: empty +{: .label .label-config .label-blue } +required: no +{: .label .label-config .label-green } +
+ +A list of endpoints to configure with cross-origin resource sharing headers. It is recommended that the `userinfo` +option is at least in this list. The potential endpoints which this can be enabled on are as follows: + +* authorization +* token +* revocation +* introspection +* userinfo + +#### allowed_origins +
+type: list(string) +{: .label .label-config .label-purple } +default: empty +{: .label .label-config .label-blue } +required: no +{: .label .label-config .label-green } +
+ +A list of permitted origins. + +Any origin with https is permitted unless this option is configured or the allowed_origins_from_client_redirect_uris +option is enabled. This means you must configure this option manually if you want http endpoints to be permitted to +make cross-origin requests to the OpenID Connect endpoints, however this is not recommended. + +Origins must only have the scheme, hostname and port, they may not have a trailing slash or path. + +In addition to an Origin URI, you may specify the wildcard origin in the allowed_origins. It MUST be specified by itself +and the allowed_origins_from_client_redirect_uris MUST NOT be enabled. The wildcard origin is denoted as `*`. Examples: + +```yaml +identity_providers: + oidc: + cors: + allowed_origins: "*" +``` + +```yaml +identity_providers: + oidc: + cors: + allowed_origins: + - "*" +``` + +#### allowed_origins_from_client_redirect_uris +
+type: boolean +{: .label .label-config .label-purple } +default: false +{: .label .label-config .label-blue } +required: no +{: .label .label-config .label-green } +
+ +Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, provided they +have the scheme http or https and do not have the hostname of localhost. + ### clients A list of clients to configure. The options for each client are described below. @@ -487,22 +569,46 @@ Below is a list of the potential values we place in the claim and their meaning: ## Endpoint Implementations -This is a table of the endpoints we currently support and their paths. This can be requrired information for some RP's, -particularly those that don't use [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html). The paths are -appended to the end of the primary URL used to access Authelia. For example in the Discovery example provided you access -Authelia via https://auth.example.com, the discovery URL is https://auth.example.com/.well-known/openid-configuration. +The following section documents the endpoints we implement and their respective paths. This information can traditionally +be discovered by relying parties that utilize [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html), +however this information may be useful for clients which do not implement this. -| Endpoint | Path | -|:-------------:|:---------------------------------------------:| -| Discovery | [root]/.well-known/openid-configuration | -| Metadata | [root]/.well-known/oauth-authorization-server | -| JWKS | [root]/api/oidc/jwks | -| Authorization | [root]/api/oidc/authorization | -| Token | [root]/api/oidc/token | -| Introspection | [root]/api/oidc/introspection | -| Revocation | [root]/api/oidc/revocation | -| Userinfo | [root]/api/oidc/userinfo | +The endpoints can be discovered easily by visiting the Discovery and Metadata endpoints. It is recommended regardless +of your version of Authelia that you utilize this version as it will always produce the correct endpoint URLs. The paths +for the Discovery/Metadata endpoints are part of IANA's well known registration but are also documented in a table below. + +These tables document the endpoints we currently support and their paths in the most recent version of Authelia. The paths +are appended to the end of the primary URL used to access Authelia. The tables use the url https://auth.example.com as +an example of the Authelia root URL which is also the OpenID Connect issuer. + +### Well Known Discovery Endpoints + +These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP. + +| Endpoint | Path | +|:-------------:|:---------------------------------------------------------------:| +| Discovery | https://auth.example.com/.well-known/openid-configuration | +| Metadata | https://auth.example.com/.well-known/oauth-authorization-server | + + +### Discoverable Endpoints + +These endpoints implement OpenID Connect elements. + +| Endpoint | Path | Discovery Attribute | +|:---------------:|:-----------------------------------------------:|:----------------------:| +| JWKS | https://auth.example.com/jwks.json | jwks_uri | +| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint | +| [Token] | https://auth.example.com/api/oidc/token | token_endpoint | +| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint | +| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint | +| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint | [OpenID Connect]: https://openid.net/connect/ [token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration -[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176 \ No newline at end of file +[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint +[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint +[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo +[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662 +[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009 +[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176 diff --git a/internal/configuration/config.template.yml b/internal/configuration/config.template.yml index db21d7af5..8ea8ae488 100644 --- a/internal/configuration/config.template.yml +++ b/internal/configuration/config.template.yml @@ -767,6 +767,26 @@ notifier: ## for security reasons. # enforce_pkce: public_clients_only + ## Cross-Origin Resource Sharing (CORS) settings. + # cors: + ## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on. + # endpoints: + # - authorization + # - token + # - revocation + # - introspection + # - userinfo + + ## List of allowed origins. + ## Any origin with https is permitted unless this option is configured or the + ## allowed_origins_from_client_redirect_uris option is enabled. + # allowed_origins: + # - https://example.com + + ## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, + ## provided they have the scheme http or https and do not have the hostname of localhost. + # allowed_origins_from_client_redirect_uris: false + ## Clients is a list of known clients and their configuration. # clients: # - diff --git a/internal/configuration/provider_test.go b/internal/configuration/provider_test.go index 52cfecfce..68c0a0b02 100644 --- a/internal/configuration/provider_test.go +++ b/internal/configuration/provider_test.go @@ -201,6 +201,24 @@ func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) { assert.Equal(t, "example_secret value", config.Storage.EncryptionKey) } +func TestShouldLoadURLList(t *testing.T) { + testReset() + + val := schema.NewStructValidator() + keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) + + assert.NoError(t, err) + + validator.ValidateKeys(keys, DefaultEnvPrefix, val) + + assert.Len(t, val.Errors(), 0) + assert.Len(t, val.Warnings(), 0) + + require.Len(t, config.IdentityProviders.OIDC.CORS.AllowedOrigins, 2) + assert.Equal(t, "https://google.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[0].String()) + assert.Equal(t, "https://example.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[1].String()) +} + func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) { testReset() diff --git a/internal/configuration/schema/identity_providers.go b/internal/configuration/schema/identity_providers.go index 4ceccabb9..3828172c1 100644 --- a/internal/configuration/schema/identity_providers.go +++ b/internal/configuration/schema/identity_providers.go @@ -1,6 +1,9 @@ package schema -import "time" +import ( + "net/url" + "time" +) // IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia. type IdentityProvidersConfiguration struct { @@ -24,9 +27,19 @@ type OpenIDConnectConfiguration struct { EnforcePKCE string `koanf:"enforce_pkce"` EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"` + CORS OpenIDConnectCORSConfiguration `koanf:"cors"` + Clients []OpenIDConnectClientConfiguration `koanf:"clients"` } +// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config. +type OpenIDConnectCORSConfiguration struct { + Endpoints []string `koanf:"endpoints"` + AllowedOrigins []url.URL `koanf:"allowed_origins"` + + AllowedOriginsFromClientRedirectURIs bool `koanf:"allowed_origins_from_client_redirect_uris"` +} + // OpenIDConnectClientConfiguration configuration for an OpenID Connect client. type OpenIDConnectClientConfiguration struct { ID string `koanf:"id"` diff --git a/internal/configuration/test_resources/config_oidc.yml b/internal/configuration/test_resources/config_oidc.yml new file mode 100644 index 000000000..974df781e --- /dev/null +++ b/internal/configuration/test_resources/config_oidc.yml @@ -0,0 +1,133 @@ +--- +default_redirection_url: https://home.example.com:8080/ + +server: + host: 127.0.0.1 + port: 9091 + +log: + level: debug + +totp: + issuer: authelia.com + +duo_api: + hostname: api-123456789.example.com + integration_key: ABCDEF + +authentication_backend: + ldap: + url: ldap://127.0.0.1 + base_dn: dc=example,dc=com + username_attribute: uid + additional_users_dn: ou=users + users_filter: (&({username_attribute}={input})(objectCategory=person)(objectClass=user)) + additional_groups_dn: ou=groups + groups_filter: (&(member={dn})(objectClass=groupOfNames)) + group_name_attribute: cn + mail_attribute: mail + user: cn=admin,dc=example,dc=com + +access_control: + default_policy: deny + + rules: + # Rules applied to everyone + - domain: public.example.com + policy: bypass + + - domain: secure.example.com + policy: one_factor + # Network based rule, if not provided any network matches. + networks: + - 192.168.1.0/24 + - domain: secure.example.com + policy: two_factor + + - domain: [singlefactor.example.com, onefactor.example.com] + policy: one_factor + + # Rules applied to 'admins' group + - domain: "mx2.mail.example.com" + subject: "group:admins" + policy: deny + - domain: "*.example.com" + subject: "group:admins" + policy: two_factor + + # Rules applied to 'dev' group + - domain: dev.example.com + resources: + - "^/groups/dev/.*$" + subject: "group:dev" + policy: two_factor + + # Rules applied to user 'john' + - domain: dev.example.com + resources: + - "^/users/john/.*$" + subject: "user:john" + policy: two_factor + + # Rules applied to 'dev' group and user 'john' + - domain: dev.example.com + resources: + - "^/deny-all.*$" + subject: ["group:dev", "user:john"] + policy: deny + + # Rules applied to user 'harry' + - domain: dev.example.com + resources: + - "^/users/harry/.*$" + subject: "user:harry" + policy: two_factor + + # Rules applied to user 'bob' + - domain: "*.mail.example.com" + subject: "user:bob" + policy: two_factor + - domain: "dev.example.com" + resources: + - "^/users/bob/.*$" + subject: "user:bob" + policy: two_factor + +session: + name: authelia_session + expiration: 3600000 # 1 hour + inactivity: 300000 # 5 minutes + domain: example.com + redis: + host: 127.0.0.1 + port: 6379 + high_availability: + sentinel_name: test + +regulation: + max_retries: 3 + find_time: 120 + ban_time: 300 + +storage: + mysql: + host: 127.0.0.1 + port: 3306 + database: authelia + username: authelia + +notifier: + smtp: + username: test + host: 127.0.0.1 + port: 1025 + sender: admin@example.com + disable_require_tls: true + +identity_providers: + oidc: + cors: + allowed_origins: + - https://google.com + - https://example.com +... diff --git a/internal/configuration/validator/const.go b/internal/configuration/validator/const.go index 21a470609..d597badfc 100644 --- a/internal/configuration/validator/const.go +++ b/internal/configuration/validator/const.go @@ -123,11 +123,15 @@ const ( const ( errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " + "more clients configured" - errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required" - + errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required" errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " + "'public_clients_only' or 'always', but it is configured as '%s'" + errFmtOIDCCORSInvalidOrigin = "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value '%s' as it has a %s: origins must only be scheme, hostname, and an optional port" + errFmtOIDCCORSInvalidOriginWildcard = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself" + errFmtOIDCCORSInvalidOriginWildcardWithClients = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled" + errFmtOIDCCORSInvalidEndpoint = "identity_providers: oidc: cors: option 'endpoints' contains an invalid value '%s': must be one of '%s'" + errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" + "id's must be unique" errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " + @@ -275,6 +279,7 @@ var validOIDCScopes = []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProf var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"} var validOIDCResponseModes = []string{"form_post", "query", "fragment"} var validOIDCUserinfoAlgorithms = []string{"none", "RS256"} +var validOIDCCORSEndpoints = []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint} var reKeyReplacer = regexp.MustCompile(`\[\d+]`) @@ -471,6 +476,9 @@ var ValidKeys = []string{ "identity_providers.oidc.enable_pkce_plain_challenge", "identity_providers.oidc.enable_client_debug_messages", "identity_providers.oidc.minimum_parameter_entropy", + "identity_providers.oidc.cors.endpoints", + "identity_providers.oidc.cors.allowed_origins", + "identity_providers.oidc.cors.enable_origins_from_clients", "identity_providers.oidc.clients", "identity_providers.oidc.clients[].id", "identity_providers.oidc.clients[].description", diff --git a/internal/configuration/validator/identity_providers.go b/internal/configuration/validator/identity_providers.go index 7f5fc8086..f1245bd7b 100644 --- a/internal/configuration/validator/identity_providers.go +++ b/internal/configuration/validator/identity_providers.go @@ -49,6 +49,7 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE)) } + validateOIDCOptionsCORS(config, validator) validateOIDCClients(config, validator) if len(config.Clients) == 0 { @@ -57,6 +58,64 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S } } +func validateOIDCOptionsCORS(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { + validateOIDCOptionsCORSAllowedOrigins(config, validator) + + if config.CORS.AllowedOriginsFromClientRedirectURIs { + validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config) + } + + validateOIDCOptionsCORSEndpoints(config, validator) +} + +func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { + for _, origin := range config.CORS.AllowedOrigins { + if origin.String() == "*" { + if len(config.CORS.AllowedOrigins) != 1 { + validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard)) + } + + if config.CORS.AllowedOriginsFromClientRedirectURIs { + validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients)) + } + + continue + } + + if origin.Path != "" { + validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path")) + } + + if origin.RawQuery != "" { + validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string")) + } + } +} + +func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) { + for _, client := range config.Clients { + for _, redirectURI := range client.RedirectURIs { + uri, err := url.Parse(redirectURI) + if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" { + continue + } + + origin := utils.OriginFromURL(*uri) + + if !utils.IsURLInSlice(origin, config.CORS.AllowedOrigins) { + config.CORS.AllowedOrigins = append(config.CORS.AllowedOrigins, origin) + } + } + } +} + +func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { + for _, endpoint := range config.CORS.Endpoints { + if !utils.IsStringInSlice(endpoint, validOIDCCORSEndpoints) { + validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '"))) + } + } +} func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { invalidID, duplicateIDs := false, false @@ -97,7 +156,6 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s validateOIDCClientResponseTypes(c, config, validator) validateOIDCClientResponseModes(c, config, validator) validateOIDDClientUserinfoAlgorithm(c, config, validator) - validateOIDCClientRedirectURIs(client, validator) } diff --git a/internal/configuration/validator/identity_providers_test.go b/internal/configuration/validator/identity_providers_test.go index 83a2af348..df1a191dd 100644 --- a/internal/configuration/validator/identity_providers_test.go +++ b/internal/configuration/validator/identity_providers_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/utils" ) func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) { @@ -29,6 +31,54 @@ func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) { assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured) } +func TestShouldNotRaiseErrorWhenCORSEndpointsValid(t *testing.T) { + validator := schema.NewStructValidator() + config := &schema.IdentityProvidersConfiguration{ + OIDC: &schema.OpenIDConnectConfiguration{ + HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN", + IssuerPrivateKey: "key-material", + CORS: schema.OpenIDConnectCORSConfiguration{ + Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint}, + }, + Clients: []schema.OpenIDConnectClientConfiguration{ + { + ID: "example", + Secret: "example", + }, + }, + }, + } + + ValidateIdentityProviders(config, validator) + + assert.Len(t, validator.Errors(), 0) +} + +func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) { + validator := schema.NewStructValidator() + config := &schema.IdentityProvidersConfiguration{ + OIDC: &schema.OpenIDConnectConfiguration{ + HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN", + IssuerPrivateKey: "key-material", + CORS: schema.OpenIDConnectCORSConfiguration{ + Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint, "invalid_endpoint"}, + }, + Clients: []schema.OpenIDConnectClientConfiguration{ + { + ID: "example", + Secret: "example", + }, + }, + }, + } + + ValidateIdentityProviders(config, validator) + + require.Len(t, validator.Errors(), 1) + + assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'") +} + func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) { validator := schema.NewStructValidator() config := &schema.IdentityProvidersConfiguration{ @@ -47,7 +97,44 @@ func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) { assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured) } -func TestShouldRaiseErrorWhenOIDCServerIssuerPrivateKeyPathInvalid(t *testing.T) { +func TestShouldRaiseErrorWhenOIDCCORSOriginsHasInvalidValues(t *testing.T) { + validator := schema.NewStructValidator() + + config := &schema.IdentityProvidersConfiguration{ + OIDC: &schema.OpenIDConnectConfiguration{ + HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN", + IssuerPrivateKey: "key-material", + CORS: schema.OpenIDConnectCORSConfiguration{ + AllowedOrigins: utils.URLsFromStringSlice([]string{"https://example.com/", "https://site.example.com/subpath", "https://site.example.com?example=true", "*"}), + AllowedOriginsFromClientRedirectURIs: true, + }, + Clients: []schema.OpenIDConnectClientConfiguration{ + { + ID: "myclient", + Secret: "jk12nb3klqwmnelqkwenm", + Policy: "two_factor", + RedirectURIs: []string{"https://example.com/oauth2_callback", "https://localhost:566/callback", "http://an.example.com/callback", "file://a/file"}, + }, + }, + }, + } + + ValidateIdentityProviders(config, validator) + + require.Len(t, validator.Errors(), 6) + assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://example.com/' as it has a path: origins must only be scheme, hostname, and an optional port") + assert.EqualError(t, validator.Errors()[1], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com/subpath' as it has a path: origins must only be scheme, hostname, and an optional port") + assert.EqualError(t, validator.Errors()[2], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com?example=true' as it has a query string: origins must only be scheme, hostname, and an optional port") + assert.EqualError(t, validator.Errors()[3], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself") + assert.EqualError(t, validator.Errors()[4], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled") + assert.EqualError(t, validator.Errors()[5], "identity_providers: oidc: client 'myclient': option 'redirect_uris' has an invalid value: redirect uri 'file://a/file' must have a scheme of 'http' or 'https' but 'file' is configured") + + require.Len(t, config.OIDC.CORS.AllowedOrigins, 6) + assert.Equal(t, "*", config.OIDC.CORS.AllowedOrigins[3].String()) + assert.Equal(t, "https://example.com", config.OIDC.CORS.AllowedOrigins[4].String()) +} + +func TestShouldRaiseErrorWhenOIDCServerNoClients(t *testing.T) { validator := schema.NewStructValidator() config := &schema.IdentityProvidersConfiguration{ OIDC: &schema.OpenIDConnectConfiguration{ diff --git a/internal/handlers/const.go b/internal/handlers/const.go index 64c174034..a43fa75b9 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -72,16 +72,6 @@ const ( auth = "auth" ) -// OIDC constants. -const ( - pathLegacyOpenIDConnectAuthorization = "/api/oidc/authorize" - pathLegacyOpenIDConnectIntrospection = "/api/oidc/introspect" - pathLegacyOpenIDConnectRevocation = "/api/oidc/revoke" - - // Note: If you change this const you must also do so in the frontend at web/src/services/Api.ts. - pathOpenIDConnectConsent = "/api/oidc/consent" -) - const ( accept = "accept" reject = "reject" diff --git a/internal/handlers/handler_oidc_jwks.go b/internal/handlers/handler_jwks.go similarity index 67% rename from internal/handlers/handler_oidc_jwks.go rename to internal/handlers/handler_jwks.go index 37e926345..14f680711 100644 --- a/internal/handlers/handler_oidc_jwks.go +++ b/internal/handlers/handler_jwks.go @@ -6,7 +6,8 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -func oidcJWKs(ctx *middlewares.AutheliaCtx) { +// JSONWebKeySetGET returns the JSON Web Key Set. Used in OAuth 2.0 and OpenID Connect 1.0. +func JSONWebKeySetGET(ctx *middlewares.AutheliaCtx) { ctx.SetContentType("application/json") if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil { diff --git a/internal/handlers/handler_oidc_introspection.go b/internal/handlers/handler_oauth_introspection.go similarity index 80% rename from internal/handlers/handler_oidc_introspection.go rename to internal/handlers/handler_oauth_introspection.go index ddc898103..331ce201d 100644 --- a/internal/handlers/handler_oidc_introspection.go +++ b/internal/handlers/handler_oauth_introspection.go @@ -9,7 +9,10 @@ import ( "github.com/authelia/authelia/v4/internal/oidc" ) -func oidcIntrospection(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { +// OAuthIntrospectionPOST handles POST requests to the OAuth 2.0 Introspection endpoint. +// +// https://datatracker.ietf.org/doc/html/rfc7662 +func OAuthIntrospectionPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { var ( responder fosite.IntrospectionResponder err error diff --git a/internal/handlers/handler_oidc_revocation.go b/internal/handlers/handler_oauth_revocation.go similarity index 64% rename from internal/handlers/handler_oidc_revocation.go rename to internal/handlers/handler_oauth_revocation.go index 84b4700cf..1dad867bc 100644 --- a/internal/handlers/handler_oidc_revocation.go +++ b/internal/handlers/handler_oauth_revocation.go @@ -8,7 +8,10 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -func oidcRevocation(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { +// OAuthRevocationPOST handles POST requests to the OAuth 2.0 Revocation endpoint. +// +// https://datatracker.ietf.org/doc/html/rfc7009 +func OAuthRevocationPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { var err error if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil { diff --git a/internal/handlers/handler_oidc_authorization.go b/internal/handlers/handler_oidc_authorization.go index 88f407568..c5940410f 100644 --- a/internal/handlers/handler_oidc_authorization.go +++ b/internal/handlers/handler_oidc_authorization.go @@ -16,7 +16,10 @@ import ( "github.com/authelia/authelia/v4/internal/session" ) -func oidcAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) { +// OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint. +// +// https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint +func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) { var ( requester fosite.AuthorizeRequester responder fosite.AuthorizeResponder diff --git a/internal/handlers/handler_oidc_consent.go b/internal/handlers/handler_oidc_consent.go index f07b52e47..b403ecb01 100644 --- a/internal/handlers/handler_oidc_consent.go +++ b/internal/handlers/handler_oidc_consent.go @@ -7,7 +7,8 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -func oidcConsent(ctx *middlewares.AutheliaCtx) { +// OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect. +func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) { userSession := ctx.GetSession() if userSession.OIDCWorkflowSession == nil { @@ -39,7 +40,8 @@ func oidcConsent(ctx *middlewares.AutheliaCtx) { } } -func oidcConsentPOST(ctx *middlewares.AutheliaCtx) { +// OpenIDConnectConsentPOST handles consent responses for OpenID Connect. +func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) { userSession := ctx.GetSession() if userSession.OIDCWorkflowSession == nil { diff --git a/internal/handlers/handler_oidc_token.go b/internal/handlers/handler_oidc_token.go index 714fcb555..59a9a55eb 100644 --- a/internal/handlers/handler_oidc_token.go +++ b/internal/handlers/handler_oidc_token.go @@ -9,7 +9,10 @@ import ( "github.com/authelia/authelia/v4/internal/oidc" ) -func oidcToken(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { +// OpenIDConnectTokenPOST handles POST requests to the OpenID Connect 1.0 Token endpoint. +// +// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint +func OpenIDConnectTokenPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { var ( requester fosite.AccessRequester responder fosite.AccessResponder diff --git a/internal/handlers/handler_oidc_userinfo.go b/internal/handlers/handler_oidc_userinfo.go index 6cc2a90df..1a46ec39f 100644 --- a/internal/handlers/handler_oidc_userinfo.go +++ b/internal/handlers/handler_oidc_userinfo.go @@ -14,7 +14,10 @@ import ( "github.com/authelia/authelia/v4/internal/oidc" ) -func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { +// OpenIDConnectUserinfo handles GET/POST requests to the OpenID Connect 1.0 UserInfo endpoint. +// +// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo +func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { var ( tokenType fosite.TokenType requester fosite.AccessRequester @@ -97,7 +100,7 @@ func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *htt var jti uuid.UUID if jti, err = uuid.NewRandom(); err != nil { - ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JWT ID.")) + ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JTI.")) return } diff --git a/internal/handlers/handler_oidc_wellknown.go b/internal/handlers/handler_oidc_wellknown.go index 3a5196c23..0efd5387c 100644 --- a/internal/handlers/handler_oidc_wellknown.go +++ b/internal/handlers/handler_oidc_wellknown.go @@ -8,7 +8,13 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) { +// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the +// OpenID Connect Discovery 1.0 metadata. +// +// https://datatracker.ietf.org/doc/html/rfc5785 +// +// https://openid.net/specs/openid-connect-discovery-1_0.html +func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) { issuer, err := ctx.ExternalRootURL() if err != nil { ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err) @@ -30,7 +36,13 @@ func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) { } } -func wellKnownOAuthAuthorizationServerGET(ctx *middlewares.AutheliaCtx) { +// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the +// OAuth 2.0 Authorization Server Metadata (RFC8414). +// +// https://datatracker.ietf.org/doc/html/rfc5785 +// +// https://datatracker.ietf.org/doc/html/rfc8414 +func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) { issuer, err := ctx.ExternalRootURL() if err != nil { ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err) diff --git a/internal/handlers/oidc_register.go b/internal/handlers/oidc_register.go deleted file mode 100644 index 58646da7d..000000000 --- a/internal/handlers/oidc_register.go +++ /dev/null @@ -1,37 +0,0 @@ -package handlers - -import ( - "github.com/fasthttp/router" - - "github.com/authelia/authelia/v4/internal/middlewares" - "github.com/authelia/authelia/v4/internal/oidc" -) - -// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout. -func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) { - // TODO: Add OPTIONS handler. - router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET))) - router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET))) - - router.GET(pathOpenIDConnectConsent, middleware(oidcConsent)) - - router.POST(pathOpenIDConnectConsent, middleware(oidcConsentPOST)) - - router.GET(oidc.JWKsPath, middleware(oidcJWKs)) - - router.GET(oidc.AuthorizationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization))) - router.GET(pathLegacyOpenIDConnectAuthorization, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization))) - - // TODO: Add OPTIONS handler. - router.POST(oidc.TokenPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcToken))) - - router.POST(oidc.IntrospectionPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection))) - router.GET(pathLegacyOpenIDConnectIntrospection, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection))) - - router.GET(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo))) - router.POST(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo))) - - // TODO: Add OPTIONS handler. - router.POST(oidc.RevocationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation))) - router.POST(pathLegacyOpenIDConnectRevocation, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation))) -} diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index e83a25a36..b0475281d 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -7,18 +7,22 @@ import ( ) var ( + headerAccept = []byte(fasthttp.HeaderAccept) + headerContentLength = []byte(fasthttp.HeaderContentLength) + headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto) headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost) headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor) headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith) - headerAccept = []byte(fasthttp.HeaderAccept) headerXForwardedURI = []byte("X-Forwarded-URI") headerXOriginalURL = []byte("X-Original-URL") headerXForwardedMethod = []byte("X-Forwarded-Method") - headerVary = []byte(fasthttp.HeaderVary) - headerOrigin = []byte(fasthttp.HeaderOrigin) + headerVary = []byte(fasthttp.HeaderVary) + headerAllow = []byte(fasthttp.HeaderAllow) + headerOrigin = []byte(fasthttp.HeaderOrigin) + headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials) headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders) headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods) @@ -29,9 +33,13 @@ var ( ) var ( - headerValueFalse = []byte("false") - headerValueMaxAge = []byte("100") - headerValueVary = []byte("Accept-Encoding, Origin") + headerValueFalse = []byte("false") + headerValueTrue = []byte("true") + headerValueMaxAge = []byte("100") + headerValueVary = []byte("Accept-Encoding, Origin") + headerValueVaryWildcard = []byte("Accept-Encoding") + headerValueOriginWildcard = []byte("*") + headerValueZero = []byte("0") ) var ( @@ -40,6 +48,8 @@ var ( // UserValueKeyBaseURL is the User Value key where we store the Base URL. UserValueKeyBaseURL = []byte("base_url") + + headerSeparator = []byte(", ") ) const ( diff --git a/internal/middlewares/cors.go b/internal/middlewares/cors.go index a4152dcf1..7936e6f70 100644 --- a/internal/middlewares/cors.go +++ b/internal/middlewares/cors.go @@ -1,53 +1,347 @@ package middlewares import ( + "bytes" "net/url" + "strconv" "strings" "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/utils" ) -// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well -// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied -// to both Accept-Encoding and Origin. It grants the GET Request Method only. -func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler { - return func(ctx *AutheliaCtx) { - if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil { - corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin) +// NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy which adds the Vary header +// with a value reflecting that the Origin header will Vary this response, then if the Origin header has a https scheme +// it makes the following additional adjustments: copies the Origin header to the Access-Control-Allow-Origin header +// effectively allowing all origins, sets the Access-Control-Allow-Credentials header to false which disallows CORS +// requests from sending cookies etc, sets the Access-Control-Allow-Headers header to the value specified by +// Access-Control-Request-Headers in the request excluding the Cookie/Authorization/Proxy-Authorization and special * +// values, sets Access-Control-Allow-Methods to the value specified by the Access-Control-Request-Method header, sets +// the Access-Control-Max-Age header to 100. +// +// These behaviours can be overridden by the With methods on the returned policy. +func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) { + return &CORSPolicyBuilder{ + enabled: true, + maxAge: 100, + } +} + +// CORSPolicyBuilder is a special middleware which provides CORS headers via handlers and middleware methods which can be +// configured. It aims to simplify CORS configurations. +type CORSPolicyBuilder struct { + enabled bool + varyOnly bool + varySet bool + methods []string + headers []string + origins []string + credentials bool + vary []string + maxAge int +} + +// Build reads the CORSPolicyBuilder configuration and generates a CORSPolicy. +func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) { + policy = &CORSPolicy{ + enabled: b.enabled, + varyOnly: b.varyOnly, + credentials: []byte(strconv.FormatBool(b.credentials)), + origins: b.buildOrigins(), + headers: b.buildHeaders(), + vary: b.buildVary(), + } + + if len(b.methods) != 0 { + policy.methods = []byte(strings.Join(b.methods, ", ")) + } + + if b.maxAge <= 0 { + policy.maxAge = headerValueMaxAge + } else { + policy.maxAge = []byte(strconv.Itoa(b.maxAge)) + } + + return policy +} + +func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) { + if len(b.origins) != 0 { + if len(b.origins) == 1 && b.origins[0] == "*" { + origins = append(origins, []byte(b.origins[0])) + } else { + for _, origin := range b.origins { + origins = append(origins, []byte(origin)) + } } + } + + return origins +} + +func (b CORSPolicyBuilder) buildHeaders() (headers []byte) { + if len(b.headers) != 0 { + h := b.headers + + if b.credentials { + if !utils.IsStringInSliceFold(fasthttp.HeaderCookie, h) { + h = append(h, fasthttp.HeaderCookie) + } + + if !utils.IsStringInSliceFold(fasthttp.HeaderAuthorization, h) { + h = append(h, fasthttp.HeaderAuthorization) + } + + if !utils.IsStringInSliceFold(fasthttp.HeaderProxyAuthorization, h) { + h = append(h, fasthttp.HeaderProxyAuthorization) + } + } + + headers = utils.JoinAndCanonicalizeHeaders(headerSeparator, h...) + } + + return headers +} + +func (b CORSPolicyBuilder) buildVary() (vary []byte) { + if b.varySet { + if len(b.vary) != 0 { + vary = utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...) + } + } else { + if len(b.origins) == 1 && b.origins[0] == "*" { + vary = headerValueVaryWildcard + } else { + vary = headerValueVary + } + } + + return vary +} + +// WithEnabled changes the enabled state of the middleware. If the middleware is initialized with NewCORSPolicyBuilder this +// value will be true but this function can override the value. Setting it to false prevents the middleware from adding +// any CORS headers. The only effect this middleware has after disabling this is the HandleOPTIONS and HandleOnlyOPTIONS +// handlers still function to return a HTTP 204 No Content, with the Allow header communicating the available HTTP +// method verbs. The main benefit of this option is that you don't have to implement complex logic to add/remove the +// middleware, you can just add it with the Middleware method, and adjust it using the WithEnabled method. +func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) { + b.enabled = enabled + + return b +} + +// WithAllowedMethods takes a list or HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with +// that value. +func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) { + b.methods = methods + + return b +} + +// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the origin matches one of these. +func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) { + b.origins = origins + + return b +} + +// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header. +func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) { + b.headers = headers + + return b +} + +// WithAllowCredentials takes bool and alters the default Access-Control-Allow-Credentials header. +func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) { + b.credentials = allow + + return b +} + +// WithVary takes a list of header strings and alters the default Vary header. +func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) { + b.vary = headers + b.varySet = true + + return b +} + +// WithVaryOnly just adds the Vary header. +func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) { + b.varyOnly = varyOnly + + return b +} + +// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header. +func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) { + b.maxAge = age + + return b +} + +// CORSPolicy is a middleware that handles adding CORS headers. +type CORSPolicy struct { + enabled bool + varyOnly bool + methods []byte + headers []byte + origins [][]byte + credentials []byte + vary []byte + maxAge []byte +} + +// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 204 +// without a body. This handler should generally not be used without using WithAllowedMethods. +func (p CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) { + p.handleOPTIONS(ctx) + p.handle(ctx) +} + +// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 204 +// without a body. This handler should generally not be used without using WithAllowedMethods. +func (p CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) { + p.handleOPTIONS(ctx) +} + +// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicyBuilder. +func (p CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) { + return func(ctx *fasthttp.RequestCtx) { + p.handle(ctx) next(ctx) } } -func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) { - originURL, err := url.Parse(string(origin)) - if err != nil || originURL.Scheme != "https" { +func (p CORSPolicy) handle(ctx *fasthttp.RequestCtx) { + if !p.enabled { return } - resp.Header.SetBytesKV(headerVary, headerValueVary) - resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin) - resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse) - resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge) + p.handleVary(ctx) - if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil { - requestedHeaders := strings.Split(string(headers), ",") - allowHeaders := make([]string, len(requestedHeaders)) + if !p.varyOnly { + p.handleCORS(ctx) + } +} - for i, header := range requestedHeaders { - headerTrimmed := strings.Trim(header, " ") - if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) { - allowHeaders[i] = headerTrimmed +func (p CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) { + ctx.Response.ResetBody() + + /* The OPTIONS method should not return a 204 as per the following specifications when read together: + + RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7): + A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in + the response. + + RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2): + A server MUST NOT send a Content-Length header field in any response with a status code of 1xx (Informational) + or 204 (No Content). + */ + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.Response.Header.SetBytesKV(headerContentLength, headerValueZero) + + if len(p.methods) != 0 { + ctx.Response.Header.SetBytesKV(headerAllow, p.methods) + } +} + +func (p CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) { + if len(p.vary) != 0 { + ctx.Response.Header.SetBytesKV(headerVary, p.vary) + } +} + +func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) { + var ( + originURL *url.URL + err error + ) + + origin := ctx.Request.Header.PeekBytes(headerOrigin) + + // Skip processing of any `https` scheme URL that has not expressly been configured. + if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) { + return + } + + var allowedOrigin []byte + + switch len(p.origins) { + case 0: + allowedOrigin = origin + default: + for i := 0; i < len(p.origins); i++ { + if bytes.Equal(p.origins[i], headerValueOriginWildcard) { + allowedOrigin = headerValueOriginWildcard + } else if bytes.Equal(p.origins[i], origin) { + allowedOrigin = origin } } - if len(allowHeaders) != 0 { - resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", "))) + if len(allowedOrigin) == 0 { + return } } - if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil { - resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods) + ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin) + + if len(p.credentials) != 0 { + ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials) + } + + if len(p.maxAge) != 0 { + ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge) + } + + p.handleAllowedHeaders(ctx) + p.handleAllowedMethods(ctx) +} + +func (p CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) { + switch len(p.methods) { + case 0: + // TODO: It may be beneficial to be able to control this automatic behaviour. + if requestMethods := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil { + ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods) + } + default: + ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods) + } +} + +func (p CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) { + switch len(p.headers) { + case 0: + // TODO: It may be beneficial to be able to control this automatic behaviour. + if headers := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil { + requestedHeaders := strings.Split(string(headers), ",") + allowHeaders := make([]string, 0, len(requestedHeaders)) + + for i := 0; i < len(requestedHeaders); i++ { + headerTrimmed := strings.Trim(requestedHeaders[i], " ") + + if headerTrimmed == "*" { + continue + } + + if bytes.Equal(p.credentials, headerValueTrue) || + (!strings.EqualFold(fasthttp.HeaderCookie, headerTrimmed) && + !strings.EqualFold(fasthttp.HeaderAuthorization, headerTrimmed) && + !strings.EqualFold(fasthttp.HeaderProxyAuthorization, headerTrimmed)) { + allowHeaders = append(allowHeaders, headerTrimmed) + } + } + + if len(allowHeaders) != 0 { + ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", "))) + } + } + default: + ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers) } } diff --git a/internal/middlewares/cors_test.go b/internal/middlewares/cors_test.go index f44106aee..9e42c103d 100644 --- a/internal/middlewares/cors_test.go +++ b/internal/middlewares/cors_test.go @@ -5,61 +5,587 @@ import ( "github.com/stretchr/testify/assert" "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/configuration/schema" ) -func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) { - req := fasthttp.AcquireRequest() - resp := fasthttp.Response{} +func TestNewCORSMiddleware(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Equal(t, 100, cors.maxAge) + assert.Equal(t, false, cors.credentials) + + assert.Nil(t, cors.methods) + assert.Nil(t, cors.origins) + assert.Nil(t, cors.headers) + assert.Nil(t, cors.vary) + assert.False(t, cors.varyOnly) + assert.False(t, cors.varySet) +} + +func TestCORSPolicyBuilder_WithEnabled(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.True(t, cors.enabled) + + cors.WithEnabled(false) + assert.False(t, cors.enabled) +} + +func TestCORSPolicyBuilder_WithVary(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Nil(t, cors.vary) + assert.False(t, cors.varyOnly) + assert.False(t, cors.varySet) + + cors.WithVary() + assert.Nil(t, cors.vary) + assert.False(t, cors.varyOnly) + assert.True(t, cors.varySet) + + cors.WithVary("Origin", "Example", "Test") + + assert.Equal(t, []string{"Origin", "Example", "Test"}, cors.vary) + assert.False(t, cors.varyOnly) + assert.True(t, cors.varySet) +} + +func TestCORSPolicyBuilder_WithAllowedMethods(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Nil(t, cors.methods) + + cors.WithAllowedMethods("GET") + + assert.Equal(t, []string{"GET"}, cors.methods) + + cors.WithAllowedMethods("POST", "PATCH") + + assert.Equal(t, []string{"POST", "PATCH"}, cors.methods) + + cors.WithAllowedMethods() + + assert.Nil(t, cors.methods) +} + +func TestCORSPolicyBuilder_WithAllowedOrigins(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Nil(t, cors.origins) + + cors.WithAllowedOrigins("https://google.com", "http://localhost") + + assert.Equal(t, []string{"https://google.com", "http://localhost"}, cors.origins) + + cors.WithAllowedOrigins() + + assert.Nil(t, cors.origins) +} + +func TestCORSPolicyBuilder_WithAllowedHeaders(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Nil(t, cors.headers) + + cors.WithAllowedHeaders("Example", "Another") + + assert.Equal(t, []string{"Example", "Another"}, cors.headers) + + cors.WithAllowedHeaders() + + assert.Nil(t, cors.headers) +} + +func TestCORSPolicyBuilder_WithAllowCredentials(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Equal(t, false, cors.credentials) + + cors.WithAllowCredentials(false) + + assert.Equal(t, false, cors.credentials) + + cors.WithAllowCredentials(true) + + assert.Equal(t, true, cors.credentials) +} + +func TestCORSPolicyBuilder_WithVaryOnly(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.False(t, cors.varyOnly) + + cors.WithVaryOnly(false) + + assert.False(t, cors.varyOnly) + + cors.WithVaryOnly(true) + + cors.WithVaryOnly(true) +} + +func TestCORSPolicyBuilder_WithMaxAge(t *testing.T) { + cors := NewCORSPolicyBuilder() + + assert.Equal(t, 100, cors.maxAge) + + cors.WithMaxAge(20) + + assert.Equal(t, 20, cors.maxAge) + + cors.WithMaxAge(0) + + assert.Equal(t, 0, cors.maxAge) +} + +func TestCORSPolicyBuilder_HandleOPTIONS(t *testing.T) { + ctx := newFastHTTPRequestCtx() origin := []byte("https://myapp.example.com") - req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) - corsApplyAutomaticAllowAllPolicy(req, &resp, origin) + cors := NewCORSPolicyBuilder() + policy := cors.Build() - assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary)) - assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin)) - assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials)) - assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge)) - assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods)) + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithAllowedMethods("GET", "OPTIONS") + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + policy = cors.Build() + policy.HandleOnlyOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithEnabled(false) + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func TestCORSPolicyBuilder_HandleOPTIONS_WithoutOrigin(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + + cors := NewCORSPolicyBuilder() + + policy := cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + + cors.WithAllowedMethods("GET", "OPTIONS") + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedOrigins(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors := NewCORSPolicyBuilder() + cors.WithAllowedOrigins("https://myapp.example.com") + + policy := cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithAllowedOrigins("https://anotherapp.example.com") + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithAllowedOrigins("*") + cors.WithAllowedMethods("GET", "OPTIONS") + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func TestCORSPolicyBuilder_WithAllowedOrigins_DoesntOverrideVary(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors := NewCORSPolicyBuilder() + cors.WithVary("Accept-Encoding", "Origin", "Test") + cors.WithAllowedOrigins("*") + + policy := cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin, Test"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func TestCORSPolicyBuilder_HandleOPTIONSWithVaryOnly(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors := NewCORSPolicyBuilder() + + cors.WithVaryOnly(true) + + policy := cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithAllowedMethods("GET", "OPTIONS") + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedHeaders(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors := NewCORSPolicyBuilder() + + cors.WithAllowedHeaders("Example", "Test") + + policy := cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithAllowedMethods("GET", "OPTIONS") + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) + + ctx = newFastHTTPRequestCtx() + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors.WithAllowCredentials(true) + + policy = cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueTrue, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("Example, Test, Cookie, Authorization, Proxy-Authorization"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func TestCORSPolicyBuilder_HandleOPTIONS_ShouldNotAllowWildcardInRequestedHeaders(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "*") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + + cors := NewCORSPolicyBuilder() + + policy := cors.Build() + policy.HandleOPTIONS(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow)) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + + cors := NewCORSPolicyBuilder() + + policy := cors.Build() + policy.handle(ctx) + + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) } func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) { - req := fasthttp.AcquireRequest() - resp := fasthttp.Response{} + ctx := newFastHTTPRequestCtx() origin := []byte("https://myapp.example.com") - req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") - req.Header.SetBytesK(headerAccessControlRequestMethod, "GET") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET") - corsApplyAutomaticAllowAllPolicy(req, &resp, origin) + cors := NewCORSPolicyBuilder() - assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary)) - assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin)) - assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials)) - assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge)) - assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) - assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods)) + policy := cors.Build() + policy.handle(ctx) + + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) } func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) { - req := fasthttp.AcquireRequest() - - resp := fasthttp.Response{} + ctx := newFastHTTPRequestCtx() origin := []byte("http://myapp.example.com") - req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") - req.Header.SetBytesK(headerAccessControlRequestMethod, "GET") + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET") - corsApplyAutomaticAllowAllPolicy(req, &resp, origin) + cors := NewCORSPolicyBuilder().WithVary() - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary)) - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin)) - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials)) - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge)) - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) - assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods)) + policy := cors.Build() + policy.handle(ctx) + + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func Test_CORSMiddleware_AsMiddleware(t *testing.T) { + ctx := newFastHTTPRequestCtx() + + origin := []byte("https://myapp.example.com") + + ctx.Request.Header.SetBytesKV(headerOrigin, origin) + ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") + ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET") + + autheliaMiddleware := AutheliaMiddleware(schema.Configuration{}, Providers{}) + + cors := NewCORSPolicyBuilder().WithAllowedMethods("GET", "OPTIONS") + + policy := cors.Build() + + route := policy.Middleware(autheliaMiddleware(testNilHandler)) + + route(ctx) + + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary)) + assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin)) + assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials)) + assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge)) + assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders)) + assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods)) +} + +func testNilHandler(_ *AutheliaCtx) {} + +func newFastHTTPRequestCtx() (ctx *fasthttp.RequestCtx) { + return &fasthttp.RequestCtx{ + Request: fasthttp.Request{}, + Response: fasthttp.Response{}, + } } diff --git a/internal/oidc/const.go b/internal/oidc/const.go index 2d128aec5..72d395dd1 100644 --- a/internal/oidc/const.go +++ b/internal/oidc/const.go @@ -19,17 +19,28 @@ const ( ClaimEmailAlts = "alt_emails" ) +// Endpoints. +const ( + AuthorizationEndpoint = "authorization" + TokenEndpoint = "token" + UserinfoEndpoint = "userinfo" + IntrospectionEndpoint = "introspection" + RevocationEndpoint = "revocation" +) + // Paths. const ( WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration" WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server" + JWKsPath = "/jwks.json" - JWKsPath = "/api/oidc/jwks" - AuthorizationPath = "/api/oidc/authorization" - TokenPath = "/api/oidc/token" //nolint:gosec // This is not a hard coded credential, it's a path. - IntrospectionPath = "/api/oidc/introspection" - RevocationPath = "/api/oidc/revocation" - UserinfoPath = "/api/oidc/userinfo" + RootPath = "/api/oidc" + + AuthorizationPath = RootPath + "/" + AuthorizationEndpoint + TokenPath = RootPath + "/" + TokenEndpoint + UserinfoPath = RootPath + "/" + UserinfoEndpoint + IntrospectionPath = RootPath + "/" + IntrospectionEndpoint + RevocationPath = RootPath + "/" + RevocationEndpoint ) // Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176 diff --git a/internal/oidc/provider_test.go b/internal/oidc/provider_test.go index c8bfba7d1..a6f6e0cb7 100644 --- a/internal/oidc/provider_test.go +++ b/internal/oidc/provider_test.go @@ -87,7 +87,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com") assert.Equal(t, "https://example.com", disco.Issuer) - assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI) + assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI) assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint) assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint) assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint) @@ -173,7 +173,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig disco := provider.GetOAuth2WellKnownConfiguration("https://example.com") assert.Equal(t, "https://example.com", disco.Issuer) - assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI) + assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI) assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint) assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint) assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint) diff --git a/internal/server/const.go b/internal/server/const.go index a00d6b880..098a5bea8 100644 --- a/internal/server/const.go +++ b/internal/server/const.go @@ -37,16 +37,17 @@ var ( {name: "/api", prefix: "/api/"}, {name: "/.well-known", prefix: "/.well-known/"}, {name: "/static", prefix: "/static/"}, + {name: "/locales", prefix: "/locales/"}, } ) -const schemeHTTP = "http" -const schemeHTTPS = "https" - const ( - dev = "dev" - f = "false" - t = "true" + dev = "dev" + f = "false" + t = "true" + localhost = "localhost" + schemeHTTP = "http" + schemeHTTPS = "https" ) const healthCheckEnv = `# Written by Authelia Process diff --git a/internal/server/error_handler.go b/internal/server/error_handler.go deleted file mode 100644 index 36b93e5b2..000000000 --- a/internal/server/error_handler.go +++ /dev/null @@ -1,28 +0,0 @@ -package server - -import ( - "net" - - "github.com/valyala/fasthttp" - - "github.com/authelia/authelia/v4/internal/logging" -) - -// Replacement for the default error handler in fasthttp. -func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) { - logger := logging.Logger() - - if _, ok := err.(*fasthttp.ErrSmallBuffer); ok { - // Note: Getting X-Forwarded-For or Request URI is impossible for ths error. - logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge) - ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge) - } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { - // TODO: Add X-Forwarded-For Check here. - logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout) - ctx.Error("request timeout", fasthttp.StatusRequestTimeout) - } else { - // TODO: Add X-Forwarded-For Check here. - logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest) - ctx.Error("error when parsing request", fasthttp.StatusBadRequest) - } -} diff --git a/internal/server/handler_notfound.go b/internal/server/handler_notfound.go deleted file mode 100644 index c8002daa4..000000000 --- a/internal/server/handler_notfound.go +++ /dev/null @@ -1,25 +0,0 @@ -package server - -import ( - "strings" - - "github.com/valyala/fasthttp" - - "github.com/authelia/authelia/v4/internal/handlers" -) - -func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - path := strings.ToLower(string(ctx.Path())) - - for i := 0; i < len(httpServerDirs); i++ { - if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) { - handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound) - - return - } - } - - next(ctx) - } -} diff --git a/internal/server/handlers.go b/internal/server/handlers.go new file mode 100644 index 000000000..1beb151ac --- /dev/null +++ b/internal/server/handlers.go @@ -0,0 +1,56 @@ +package server + +import ( + "net" + "strings" + + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/handlers" + "github.com/authelia/authelia/v4/internal/logging" +) + +// Replacement for the default error handler in fasthttp. +func handlerErrors(ctx *fasthttp.RequestCtx, err error) { + logger := logging.Logger() + + switch e := err.(type) { + case *fasthttp.ErrSmallBuffer: + logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge) + ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge) + case *net.OpError: + if e.Timeout() { + // TODO: Add X-Forwarded-For Check here. + logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout) + ctx.Error("request timeout", fasthttp.StatusRequestTimeout) + } else { + // TODO: Add X-Forwarded-For Check here. + logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest) + ctx.Error("error when parsing request", fasthttp.StatusBadRequest) + } + default: + // TODO: Add X-Forwarded-For Check here. + logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest) + ctx.Error("error when parsing request", fasthttp.StatusBadRequest) + } +} + +func handlerNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + path := strings.ToLower(string(ctx.Path())) + + for i := 0; i < len(httpServerDirs); i++ { + if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) { + handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound) + + return + } + } + + next(ctx) + } +} + +func handlerMethodNotAllowed(ctx *fasthttp.RequestCtx) { + handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed) +} diff --git a/internal/server/options_handler.go b/internal/server/options_handler.go deleted file mode 100644 index 7ebd15d3c..000000000 --- a/internal/server/options_handler.go +++ /dev/null @@ -1,11 +0,0 @@ -package server - -import ( - "github.com/valyala/fasthttp" - - "github.com/authelia/authelia/v4/internal/middlewares" -) - -func handleOPTIONS(ctx *middlewares.AutheliaCtx) { - ctx.SetStatusCode(fasthttp.StatusNoContent) -} diff --git a/internal/server/server.go b/internal/server/server.go index 9e40679af..985fd5b20 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -19,10 +19,12 @@ import ( "github.com/authelia/authelia/v4/internal/handlers" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/utils" ) +// TODO: move to its own file and rename configuration -> config. func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { - autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers) rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled) resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword) @@ -33,37 +35,49 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment) } - handlerPublicHTML := newPublicHTMLEmbeddedHandler() - handlerLocales := newLocalesEmbeddedHandler() - https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != "" serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https) serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https) serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https) + handlerPublicHTML := newPublicHTMLEmbeddedHandler() + handlerLocales := newLocalesEmbeddedHandler() + + autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers) + + policyCORSPublicGET := middlewares.NewCORSPolicyBuilder(). + WithAllowedMethods("OPTIONS", "GET"). + WithAllowedOrigins("*"). + Build() + r := router.New() + + // Static Assets. r.GET("/", autheliaMiddleware(serveIndexHandler)) - r.OPTIONS("/", autheliaMiddleware(handleOPTIONS)) for _, f := range rootFiles { r.GET("/"+f, handlerPublicHTML) } - r.GET("/api/", autheliaMiddleware(serveSwaggerHandler)) - r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler)) - - for _, file := range swaggerFiles { - r.GET("/api/"+file, handlerPublicHTML) - } - r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML)) r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML)) r.GET("/static/{filepath:*}", handlerPublicHTML) + // Locales. r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales)) r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales)) + // Swagger. + r.GET("/api/", autheliaMiddleware(serveSwaggerHandler)) + r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS) + r.GET("/api/"+apiFile, policyCORSPublicGET.Middleware(autheliaMiddleware(serveSwaggerAPIHandler))) + r.OPTIONS("/api/"+apiFile, policyCORSPublicGET.HandleOPTIONS) + + for _, file := range swaggerFiles { + r.GET("/api/"+file, handlerPublicHTML) + } + r.GET("/api/health", autheliaMiddleware(handlers.HealthGet)) r.GET("/api/state", autheliaMiddleware(handlers.StateGet)) @@ -161,22 +175,98 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr r.GET("/debug/vars", expvarhandler.ExpvarHandler) } - r.NotFound = handleNotFound(autheliaMiddleware(serveIndexHandler)) + if providers.OpenIDConnect.Fosite != nil { + r.GET("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentGET)) + r.POST("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentPOST)) + + allowedOrigins := utils.StringSliceFromURLs(configuration.IdentityProviders.OIDC.CORS.AllowedOrigins) + + r.OPTIONS(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.HandleOPTIONS) + r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OpenIDConnectConfigurationWellKnownGET))) + + r.OPTIONS(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.HandleOPTIONS) + r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OAuthAuthorizationServerWellKnownGET))) + + r.OPTIONS(oidc.JWKsPath, policyCORSPublicGET.HandleOPTIONS) + r.GET(oidc.JWKsPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET))) + + // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. + r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS) + r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET))) + + policyCORSAuthorization := middlewares.NewCORSPolicyBuilder(). + WithAllowedMethods("OPTIONS", "GET"). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.AuthorizationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)). + Build() + + r.OPTIONS(oidc.AuthorizationPath, policyCORSAuthorization.HandleOnlyOPTIONS) + r.GET(oidc.AuthorizationPath, autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET))) + + // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint. + r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS) + r.GET("/api/oidc/authorize", autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET))) + + policyCORSToken := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods("OPTIONS", "POST"). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.TokenEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)). + Build() + + r.OPTIONS(oidc.TokenPath, policyCORSToken.HandleOPTIONS) + r.POST(oidc.TokenPath, policyCORSToken.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST)))) + + policyCORSUserinfo := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods("OPTIONS", "GET", "POST"). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.UserinfoEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)). + Build() + + r.OPTIONS(oidc.UserinfoPath, policyCORSUserinfo.HandleOPTIONS) + r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))) + r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo)))) + + policyCORSIntrospection := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods("OPTIONS", "POST"). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.IntrospectionEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)). + Build() + + r.OPTIONS(oidc.IntrospectionPath, policyCORSIntrospection.HandleOPTIONS) + r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))) + + // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. + r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS) + r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST)))) + + policyCORSRevocation := middlewares.NewCORSPolicyBuilder(). + WithAllowCredentials(true). + WithAllowedMethods("OPTIONS", "POST"). + WithAllowedOrigins(allowedOrigins...). + WithEnabled(utils.IsStringInSlice(oidc.RevocationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)). + Build() + + r.OPTIONS(oidc.RevocationPath, policyCORSRevocation.HandleOPTIONS) + r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))) + + // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint. + r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS) + r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST)))) + } + + r.NotFound = handlerNotFound(autheliaMiddleware(serveIndexHandler)) r.HandleMethodNotAllowed = true - r.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) { - handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed) - } + r.MethodNotAllowed = handlerMethodNotAllowed handler := middlewares.LogRequestMiddleware(r.Handler) if configuration.Server.Path != "" { handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler) } - if providers.OpenIDConnect.Fosite != nil { - handlers.RegisterOIDC(r, autheliaMiddleware) - } - return handler } @@ -185,12 +275,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov handler := registerRoutes(configuration, providers) server := &fasthttp.Server{ - ErrorHandler: autheliaErrorHandler, + ErrorHandler: handlerErrors, Handler: handler, NoDefaultServerHeader: true, ReadBufferSize: configuration.Server.ReadBufferSize, WriteBufferSize: configuration.Server.WriteBufferSize, } + logger := logging.Logger() address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port)) @@ -204,9 +295,8 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" { connectionType, connectionScheme = "TLS", schemeHTTPS - err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key) - if err != nil { + if err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key); err != nil { logger.Fatalf("unable to load certificate: %v", err) } @@ -228,14 +318,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert } - listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()) - if err != nil { + if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil { logger.Fatalf("Error initializing listener: %s", err) } } else { connectionType, connectionScheme = "non-TLS", schemeHTTP - listener, err = net.Listen("tcp", address) - if err != nil { + + if listener, err = net.Listen("tcp", address); err != nil { logger.Fatalf("Error initializing listener: %s", err) } } @@ -245,11 +334,10 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov logger.Fatalf("Could not configure healthcheck: %v", err) } - actualAddress := listener.Addr().String() if configuration.Server.Path == "" { - logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress) + logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, listener.Addr().String()) } else { - logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path) + logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, listener.Addr().String(), configuration.Server.Path) } return server, listener diff --git a/internal/server/template.go b/internal/server/template.go index b71041bc6..31b812500 100644 --- a/internal/server/template.go +++ b/internal/server/template.go @@ -48,14 +48,14 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM } } - var scheme = "https" + var scheme = schemeHTTPS if !https { proto := string(ctx.XForwardedProto()) switch proto { case "": break - case "http", "https": + case schemeHTTP, schemeHTTPS: scheme = proto } } @@ -116,7 +116,7 @@ func writeHealthCheckEnv(disabled bool, scheme, host, path string, port int) (er }() if host == "0.0.0.0" { - host = "localhost" + host = localhost } else if strings.Contains(host, ":") { host = "[" + host + "]" } diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 19e049200..231621f27 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -8,6 +8,8 @@ import ( "strings" "time" "unicode" + + "github.com/valyala/fasthttp" ) // IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error @@ -145,6 +147,52 @@ func IsStringSlicesDifferentFold(a, b []string) (different bool) { return isStringSlicesDifferent(a, b, IsStringInSliceFold) } +// IsURLInSlice returns true if the needle url.URL is in the []url.URL haystack. +func IsURLInSlice(needle url.URL, haystack []url.URL) (has bool) { + for i := 0; i < len(haystack); i++ { + if strings.EqualFold(needle.String(), haystack[i].String()) { + return true + } + } + + return false +} + +// StringSliceFromURLs returns a []string from a []url.URL. +func StringSliceFromURLs(urls []url.URL) []string { + result := make([]string, len(urls)) + + for i := 0; i < len(urls); i++ { + result[i] = urls[i].String() + } + + return result +} + +// URLsFromStringSlice returns a []url.URL from a []string. +func URLsFromStringSlice(urls []string) []url.URL { + var result []url.URL + + for i := 0; i < len(urls); i++ { + u, err := url.Parse(urls[i]) + if err != nil { + continue + } + + result = append(result, *u) + } + + return result +} + +// OriginFromURL returns an origin url.URL given another url.URL. +func OriginFromURL(u url.URL) (origin url.URL) { + return url.URL{ + Scheme: u.Scheme, + Host: u.Host, + } +} + // StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string. func StringSlicesDelta(before, after []string) (added, removed []string) { for _, s := range before { @@ -193,6 +241,19 @@ func StringHTMLEscape(input string) (output string) { return htmlEscaper.Replace(input) } +// JoinAndCanonicalizeHeaders join header strings by a given sep. +func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) { + for i, header := range headers { + if i != 0 { + joined = append(joined, sep...) + } + + joined = fasthttp.AppendNormalizedHeaderKey(joined, header) + } + + return joined +} + func init() { rand.Seed(time.Now().UnixNano()) } diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index c006f57e3..2c9ad981e 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -1,6 +1,7 @@ package utils import ( + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -171,3 +172,48 @@ func TestIsStringSliceContainsAny(t *testing.T) { assert.False(t, IsStringSliceContainsAny(needles, haystackOne)) assert.True(t, IsStringSliceContainsAny(needles, haystackTwo)) } + +func TestStringSliceURLConversionFuncs(t *testing.T) { + urls := URLsFromStringSlice([]string{"https://google.com", "abc", "%*()@#$J(@*#$J@#($H"}) + + require.Len(t, urls, 2) + assert.Equal(t, "https://google.com", urls[0].String()) + assert.Equal(t, "abc", urls[1].String()) + + strs := StringSliceFromURLs(urls) + + require.Len(t, strs, 2) + assert.Equal(t, "https://google.com", strs[0]) + assert.Equal(t, "abc", strs[1]) +} + +func TestIsURLInSlice(t *testing.T) { + urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"}) + + google, err := url.Parse("https://google.com") + assert.NoError(t, err) + + microsoft, err := url.Parse("https://microsoft.com") + assert.NoError(t, err) + + example, err := url.Parse("https://example.com") + assert.NoError(t, err) + + assert.True(t, IsURLInSlice(*google, urls)) + assert.False(t, IsURLInSlice(*microsoft, urls)) + assert.True(t, IsURLInSlice(*example, urls)) +} + +func TestOriginFromURL(t *testing.T) { + google, err := url.Parse("https://google.com/abc?a=123#five") + assert.NoError(t, err) + + origin := OriginFromURL(*google) + assert.Equal(t, "https://google.com", origin.String()) +} + +func TestJoinAndCanonicalizeHeaders(t *testing.T) { + result := JoinAndCanonicalizeHeaders([]byte(", "), "x-example-ONE", "X-EGG-Two") + + assert.Equal(t, []byte("X-Example-One, X-Egg-Two"), result) +}