diff --git a/config.template.yml b/config.template.yml
index db21d7af5..8ea8ae488 100644
--- a/config.template.yml
+++ b/config.template.yml
@@ -767,6 +767,26 @@ notifier:
## for security reasons.
# enforce_pkce: public_clients_only
+ ## Cross-Origin Resource Sharing (CORS) settings.
+ # cors:
+ ## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
+ # endpoints:
+ # - authorization
+ # - token
+ # - revocation
+ # - introspection
+ # - userinfo
+
+ ## List of allowed origins.
+ ## Any origin with https is permitted unless this option is configured or the
+ ## allowed_origins_from_client_redirect_uris option is enabled.
+ # allowed_origins:
+ # - https://example.com
+
+ ## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
+ ## provided they have the scheme http or https and do not have the hostname of localhost.
+ # allowed_origins_from_client_redirect_uris: false
+
## Clients is a list of known clients and their configuration.
# clients:
# -
diff --git a/docs/configuration/identity-providers/oidc.md b/docs/configuration/identity-providers/oidc.md
index 1c1dc0278..f2ffd2766 100644
--- a/docs/configuration/identity-providers/oidc.md
+++ b/docs/configuration/identity-providers/oidc.md
@@ -35,6 +35,15 @@ identity_providers:
refresh_token_lifespan: 90m
enable_client_debug_messages: false
enforce_pkce: public_clients_only
+ cors:
+ endpoints:
+ - authorization
+ - token
+ - revocation
+ - introspection
+ allowed_origins:
+ - https://example.com
+ allowed_origins_from_client_redirect_uris: false
clients:
- id: myapp
description: My Application
@@ -218,6 +227,79 @@ Allows PKCE `plain` challenges when set to `true`.
***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead.
+### cors
+
+Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
+you to configure the optional parts. We reply with CORS headers when the request includes the Origin header.
+
+##### endpoints
+
+type: list(string)
+{: .label .label-config .label-purple }
+default: empty
+{: .label .label-config .label-blue }
+required: no
+{: .label .label-config .label-green }
+
+
+A list of endpoints to configure with cross-origin resource sharing headers. It is recommended that the `userinfo`
+option is at least in this list. The potential endpoints which this can be enabled on are as follows:
+
+* authorization
+* token
+* revocation
+* introspection
+* userinfo
+
+#### allowed_origins
+
+type: list(string)
+{: .label .label-config .label-purple }
+default: empty
+{: .label .label-config .label-blue }
+required: no
+{: .label .label-config .label-green }
+
+
+A list of permitted origins.
+
+Any origin with https is permitted unless this option is configured or the allowed_origins_from_client_redirect_uris
+option is enabled. This means you must configure this option manually if you want http endpoints to be permitted to
+make cross-origin requests to the OpenID Connect endpoints, however this is not recommended.
+
+Origins must only have the scheme, hostname and port, they may not have a trailing slash or path.
+
+In addition to an Origin URI, you may specify the wildcard origin in the allowed_origins. It MUST be specified by itself
+and the allowed_origins_from_client_redirect_uris MUST NOT be enabled. The wildcard origin is denoted as `*`. Examples:
+
+```yaml
+identity_providers:
+ oidc:
+ cors:
+ allowed_origins: "*"
+```
+
+```yaml
+identity_providers:
+ oidc:
+ cors:
+ allowed_origins:
+ - "*"
+```
+
+#### allowed_origins_from_client_redirect_uris
+
+type: boolean
+{: .label .label-config .label-purple }
+default: false
+{: .label .label-config .label-blue }
+required: no
+{: .label .label-config .label-green }
+
+
+Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, provided they
+have the scheme http or https and do not have the hostname of localhost.
+
### clients
A list of clients to configure. The options for each client are described below.
@@ -487,22 +569,46 @@ Below is a list of the potential values we place in the claim and their meaning:
## Endpoint Implementations
-This is a table of the endpoints we currently support and their paths. This can be requrired information for some RP's,
-particularly those that don't use [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html). The paths are
-appended to the end of the primary URL used to access Authelia. For example in the Discovery example provided you access
-Authelia via https://auth.example.com, the discovery URL is https://auth.example.com/.well-known/openid-configuration.
+The following section documents the endpoints we implement and their respective paths. This information can traditionally
+be discovered by relying parties that utilize [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html),
+however this information may be useful for clients which do not implement this.
-| Endpoint | Path |
-|:-------------:|:---------------------------------------------:|
-| Discovery | [root]/.well-known/openid-configuration |
-| Metadata | [root]/.well-known/oauth-authorization-server |
-| JWKS | [root]/api/oidc/jwks |
-| Authorization | [root]/api/oidc/authorization |
-| Token | [root]/api/oidc/token |
-| Introspection | [root]/api/oidc/introspection |
-| Revocation | [root]/api/oidc/revocation |
-| Userinfo | [root]/api/oidc/userinfo |
+The endpoints can be discovered easily by visiting the Discovery and Metadata endpoints. It is recommended regardless
+of your version of Authelia that you utilize this version as it will always produce the correct endpoint URLs. The paths
+for the Discovery/Metadata endpoints are part of IANA's well known registration but are also documented in a table below.
+
+These tables document the endpoints we currently support and their paths in the most recent version of Authelia. The paths
+are appended to the end of the primary URL used to access Authelia. The tables use the url https://auth.example.com as
+an example of the Authelia root URL which is also the OpenID Connect issuer.
+
+### Well Known Discovery Endpoints
+
+These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP.
+
+| Endpoint | Path |
+|:-------------:|:---------------------------------------------------------------:|
+| Discovery | https://auth.example.com/.well-known/openid-configuration |
+| Metadata | https://auth.example.com/.well-known/oauth-authorization-server |
+
+
+### Discoverable Endpoints
+
+These endpoints implement OpenID Connect elements.
+
+| Endpoint | Path | Discovery Attribute |
+|:---------------:|:-----------------------------------------------:|:----------------------:|
+| JWKS | https://auth.example.com/jwks.json | jwks_uri |
+| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
+| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
+| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
+| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
+| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
[OpenID Connect]: https://openid.net/connect/
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
-[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
\ No newline at end of file
+[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
+[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
+[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
+[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
+[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
+[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
diff --git a/internal/configuration/config.template.yml b/internal/configuration/config.template.yml
index db21d7af5..8ea8ae488 100644
--- a/internal/configuration/config.template.yml
+++ b/internal/configuration/config.template.yml
@@ -767,6 +767,26 @@ notifier:
## for security reasons.
# enforce_pkce: public_clients_only
+ ## Cross-Origin Resource Sharing (CORS) settings.
+ # cors:
+ ## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
+ # endpoints:
+ # - authorization
+ # - token
+ # - revocation
+ # - introspection
+ # - userinfo
+
+ ## List of allowed origins.
+ ## Any origin with https is permitted unless this option is configured or the
+ ## allowed_origins_from_client_redirect_uris option is enabled.
+ # allowed_origins:
+ # - https://example.com
+
+ ## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
+ ## provided they have the scheme http or https and do not have the hostname of localhost.
+ # allowed_origins_from_client_redirect_uris: false
+
## Clients is a list of known clients and their configuration.
# clients:
# -
diff --git a/internal/configuration/provider_test.go b/internal/configuration/provider_test.go
index 52cfecfce..68c0a0b02 100644
--- a/internal/configuration/provider_test.go
+++ b/internal/configuration/provider_test.go
@@ -201,6 +201,24 @@ func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) {
assert.Equal(t, "example_secret value", config.Storage.EncryptionKey)
}
+func TestShouldLoadURLList(t *testing.T) {
+ testReset()
+
+ val := schema.NewStructValidator()
+ keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
+
+ assert.NoError(t, err)
+
+ validator.ValidateKeys(keys, DefaultEnvPrefix, val)
+
+ assert.Len(t, val.Errors(), 0)
+ assert.Len(t, val.Warnings(), 0)
+
+ require.Len(t, config.IdentityProviders.OIDC.CORS.AllowedOrigins, 2)
+ assert.Equal(t, "https://google.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[0].String())
+ assert.Equal(t, "https://example.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[1].String())
+}
+
func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
testReset()
diff --git a/internal/configuration/schema/identity_providers.go b/internal/configuration/schema/identity_providers.go
index 4ceccabb9..3828172c1 100644
--- a/internal/configuration/schema/identity_providers.go
+++ b/internal/configuration/schema/identity_providers.go
@@ -1,6 +1,9 @@
package schema
-import "time"
+import (
+ "net/url"
+ "time"
+)
// IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia.
type IdentityProvidersConfiguration struct {
@@ -24,9 +27,19 @@ type OpenIDConnectConfiguration struct {
EnforcePKCE string `koanf:"enforce_pkce"`
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
+ CORS OpenIDConnectCORSConfiguration `koanf:"cors"`
+
Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
}
+// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config.
+type OpenIDConnectCORSConfiguration struct {
+ Endpoints []string `koanf:"endpoints"`
+ AllowedOrigins []url.URL `koanf:"allowed_origins"`
+
+ AllowedOriginsFromClientRedirectURIs bool `koanf:"allowed_origins_from_client_redirect_uris"`
+}
+
// OpenIDConnectClientConfiguration configuration for an OpenID Connect client.
type OpenIDConnectClientConfiguration struct {
ID string `koanf:"id"`
diff --git a/internal/configuration/test_resources/config_oidc.yml b/internal/configuration/test_resources/config_oidc.yml
new file mode 100644
index 000000000..974df781e
--- /dev/null
+++ b/internal/configuration/test_resources/config_oidc.yml
@@ -0,0 +1,133 @@
+---
+default_redirection_url: https://home.example.com:8080/
+
+server:
+ host: 127.0.0.1
+ port: 9091
+
+log:
+ level: debug
+
+totp:
+ issuer: authelia.com
+
+duo_api:
+ hostname: api-123456789.example.com
+ integration_key: ABCDEF
+
+authentication_backend:
+ ldap:
+ url: ldap://127.0.0.1
+ base_dn: dc=example,dc=com
+ username_attribute: uid
+ additional_users_dn: ou=users
+ users_filter: (&({username_attribute}={input})(objectCategory=person)(objectClass=user))
+ additional_groups_dn: ou=groups
+ groups_filter: (&(member={dn})(objectClass=groupOfNames))
+ group_name_attribute: cn
+ mail_attribute: mail
+ user: cn=admin,dc=example,dc=com
+
+access_control:
+ default_policy: deny
+
+ rules:
+ # Rules applied to everyone
+ - domain: public.example.com
+ policy: bypass
+
+ - domain: secure.example.com
+ policy: one_factor
+ # Network based rule, if not provided any network matches.
+ networks:
+ - 192.168.1.0/24
+ - domain: secure.example.com
+ policy: two_factor
+
+ - domain: [singlefactor.example.com, onefactor.example.com]
+ policy: one_factor
+
+ # Rules applied to 'admins' group
+ - domain: "mx2.mail.example.com"
+ subject: "group:admins"
+ policy: deny
+ - domain: "*.example.com"
+ subject: "group:admins"
+ policy: two_factor
+
+ # Rules applied to 'dev' group
+ - domain: dev.example.com
+ resources:
+ - "^/groups/dev/.*$"
+ subject: "group:dev"
+ policy: two_factor
+
+ # Rules applied to user 'john'
+ - domain: dev.example.com
+ resources:
+ - "^/users/john/.*$"
+ subject: "user:john"
+ policy: two_factor
+
+ # Rules applied to 'dev' group and user 'john'
+ - domain: dev.example.com
+ resources:
+ - "^/deny-all.*$"
+ subject: ["group:dev", "user:john"]
+ policy: deny
+
+ # Rules applied to user 'harry'
+ - domain: dev.example.com
+ resources:
+ - "^/users/harry/.*$"
+ subject: "user:harry"
+ policy: two_factor
+
+ # Rules applied to user 'bob'
+ - domain: "*.mail.example.com"
+ subject: "user:bob"
+ policy: two_factor
+ - domain: "dev.example.com"
+ resources:
+ - "^/users/bob/.*$"
+ subject: "user:bob"
+ policy: two_factor
+
+session:
+ name: authelia_session
+ expiration: 3600000 # 1 hour
+ inactivity: 300000 # 5 minutes
+ domain: example.com
+ redis:
+ host: 127.0.0.1
+ port: 6379
+ high_availability:
+ sentinel_name: test
+
+regulation:
+ max_retries: 3
+ find_time: 120
+ ban_time: 300
+
+storage:
+ mysql:
+ host: 127.0.0.1
+ port: 3306
+ database: authelia
+ username: authelia
+
+notifier:
+ smtp:
+ username: test
+ host: 127.0.0.1
+ port: 1025
+ sender: admin@example.com
+ disable_require_tls: true
+
+identity_providers:
+ oidc:
+ cors:
+ allowed_origins:
+ - https://google.com
+ - https://example.com
+...
diff --git a/internal/configuration/validator/const.go b/internal/configuration/validator/const.go
index 21a470609..d597badfc 100644
--- a/internal/configuration/validator/const.go
+++ b/internal/configuration/validator/const.go
@@ -123,11 +123,15 @@ const (
const (
errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " +
"more clients configured"
- errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
-
+ errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " +
"'public_clients_only' or 'always', but it is configured as '%s'"
+ errFmtOIDCCORSInvalidOrigin = "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value '%s' as it has a %s: origins must only be scheme, hostname, and an optional port"
+ errFmtOIDCCORSInvalidOriginWildcard = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself"
+ errFmtOIDCCORSInvalidOriginWildcardWithClients = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled"
+ errFmtOIDCCORSInvalidEndpoint = "identity_providers: oidc: cors: option 'endpoints' contains an invalid value '%s': must be one of '%s'"
+
errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" +
"id's must be unique"
errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " +
@@ -275,6 +279,7 @@ var validOIDCScopes = []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProf
var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}
var validOIDCResponseModes = []string{"form_post", "query", "fragment"}
var validOIDCUserinfoAlgorithms = []string{"none", "RS256"}
+var validOIDCCORSEndpoints = []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint}
var reKeyReplacer = regexp.MustCompile(`\[\d+]`)
@@ -471,6 +476,9 @@ var ValidKeys = []string{
"identity_providers.oidc.enable_pkce_plain_challenge",
"identity_providers.oidc.enable_client_debug_messages",
"identity_providers.oidc.minimum_parameter_entropy",
+ "identity_providers.oidc.cors.endpoints",
+ "identity_providers.oidc.cors.allowed_origins",
+ "identity_providers.oidc.cors.enable_origins_from_clients",
"identity_providers.oidc.clients",
"identity_providers.oidc.clients[].id",
"identity_providers.oidc.clients[].description",
diff --git a/internal/configuration/validator/identity_providers.go b/internal/configuration/validator/identity_providers.go
index 7f5fc8086..f1245bd7b 100644
--- a/internal/configuration/validator/identity_providers.go
+++ b/internal/configuration/validator/identity_providers.go
@@ -49,6 +49,7 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE))
}
+ validateOIDCOptionsCORS(config, validator)
validateOIDCClients(config, validator)
if len(config.Clients) == 0 {
@@ -57,6 +58,64 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
}
}
+func validateOIDCOptionsCORS(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
+ validateOIDCOptionsCORSAllowedOrigins(config, validator)
+
+ if config.CORS.AllowedOriginsFromClientRedirectURIs {
+ validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config)
+ }
+
+ validateOIDCOptionsCORSEndpoints(config, validator)
+}
+
+func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
+ for _, origin := range config.CORS.AllowedOrigins {
+ if origin.String() == "*" {
+ if len(config.CORS.AllowedOrigins) != 1 {
+ validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard))
+ }
+
+ if config.CORS.AllowedOriginsFromClientRedirectURIs {
+ validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients))
+ }
+
+ continue
+ }
+
+ if origin.Path != "" {
+ validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path"))
+ }
+
+ if origin.RawQuery != "" {
+ validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string"))
+ }
+ }
+}
+
+func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
+ for _, client := range config.Clients {
+ for _, redirectURI := range client.RedirectURIs {
+ uri, err := url.Parse(redirectURI)
+ if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
+ continue
+ }
+
+ origin := utils.OriginFromURL(*uri)
+
+ if !utils.IsURLInSlice(origin, config.CORS.AllowedOrigins) {
+ config.CORS.AllowedOrigins = append(config.CORS.AllowedOrigins, origin)
+ }
+ }
+ }
+}
+
+func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
+ for _, endpoint := range config.CORS.Endpoints {
+ if !utils.IsStringInSlice(endpoint, validOIDCCORSEndpoints) {
+ validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '")))
+ }
+ }
+}
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
invalidID, duplicateIDs := false, false
@@ -97,7 +156,6 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s
validateOIDCClientResponseTypes(c, config, validator)
validateOIDCClientResponseModes(c, config, validator)
validateOIDDClientUserinfoAlgorithm(c, config, validator)
-
validateOIDCClientRedirectURIs(client, validator)
}
diff --git a/internal/configuration/validator/identity_providers_test.go b/internal/configuration/validator/identity_providers_test.go
index 83a2af348..df1a191dd 100644
--- a/internal/configuration/validator/identity_providers_test.go
+++ b/internal/configuration/validator/identity_providers_test.go
@@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/oidc"
+ "github.com/authelia/authelia/v4/internal/utils"
)
func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
@@ -29,6 +31,54 @@ func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
}
+func TestShouldNotRaiseErrorWhenCORSEndpointsValid(t *testing.T) {
+ validator := schema.NewStructValidator()
+ config := &schema.IdentityProvidersConfiguration{
+ OIDC: &schema.OpenIDConnectConfiguration{
+ HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
+ IssuerPrivateKey: "key-material",
+ CORS: schema.OpenIDConnectCORSConfiguration{
+ Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint},
+ },
+ Clients: []schema.OpenIDConnectClientConfiguration{
+ {
+ ID: "example",
+ Secret: "example",
+ },
+ },
+ },
+ }
+
+ ValidateIdentityProviders(config, validator)
+
+ assert.Len(t, validator.Errors(), 0)
+}
+
+func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) {
+ validator := schema.NewStructValidator()
+ config := &schema.IdentityProvidersConfiguration{
+ OIDC: &schema.OpenIDConnectConfiguration{
+ HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
+ IssuerPrivateKey: "key-material",
+ CORS: schema.OpenIDConnectCORSConfiguration{
+ Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint, "invalid_endpoint"},
+ },
+ Clients: []schema.OpenIDConnectClientConfiguration{
+ {
+ ID: "example",
+ Secret: "example",
+ },
+ },
+ },
+ }
+
+ ValidateIdentityProviders(config, validator)
+
+ require.Len(t, validator.Errors(), 1)
+
+ assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'")
+}
+
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{
@@ -47,7 +97,44 @@ func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
}
-func TestShouldRaiseErrorWhenOIDCServerIssuerPrivateKeyPathInvalid(t *testing.T) {
+func TestShouldRaiseErrorWhenOIDCCORSOriginsHasInvalidValues(t *testing.T) {
+ validator := schema.NewStructValidator()
+
+ config := &schema.IdentityProvidersConfiguration{
+ OIDC: &schema.OpenIDConnectConfiguration{
+ HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
+ IssuerPrivateKey: "key-material",
+ CORS: schema.OpenIDConnectCORSConfiguration{
+ AllowedOrigins: utils.URLsFromStringSlice([]string{"https://example.com/", "https://site.example.com/subpath", "https://site.example.com?example=true", "*"}),
+ AllowedOriginsFromClientRedirectURIs: true,
+ },
+ Clients: []schema.OpenIDConnectClientConfiguration{
+ {
+ ID: "myclient",
+ Secret: "jk12nb3klqwmnelqkwenm",
+ Policy: "two_factor",
+ RedirectURIs: []string{"https://example.com/oauth2_callback", "https://localhost:566/callback", "http://an.example.com/callback", "file://a/file"},
+ },
+ },
+ },
+ }
+
+ ValidateIdentityProviders(config, validator)
+
+ require.Len(t, validator.Errors(), 6)
+ assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://example.com/' as it has a path: origins must only be scheme, hostname, and an optional port")
+ assert.EqualError(t, validator.Errors()[1], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com/subpath' as it has a path: origins must only be scheme, hostname, and an optional port")
+ assert.EqualError(t, validator.Errors()[2], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com?example=true' as it has a query string: origins must only be scheme, hostname, and an optional port")
+ assert.EqualError(t, validator.Errors()[3], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself")
+ assert.EqualError(t, validator.Errors()[4], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled")
+ assert.EqualError(t, validator.Errors()[5], "identity_providers: oidc: client 'myclient': option 'redirect_uris' has an invalid value: redirect uri 'file://a/file' must have a scheme of 'http' or 'https' but 'file' is configured")
+
+ require.Len(t, config.OIDC.CORS.AllowedOrigins, 6)
+ assert.Equal(t, "*", config.OIDC.CORS.AllowedOrigins[3].String())
+ assert.Equal(t, "https://example.com", config.OIDC.CORS.AllowedOrigins[4].String())
+}
+
+func TestShouldRaiseErrorWhenOIDCServerNoClients(t *testing.T) {
validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{
OIDC: &schema.OpenIDConnectConfiguration{
diff --git a/internal/handlers/const.go b/internal/handlers/const.go
index 64c174034..a43fa75b9 100644
--- a/internal/handlers/const.go
+++ b/internal/handlers/const.go
@@ -72,16 +72,6 @@ const (
auth = "auth"
)
-// OIDC constants.
-const (
- pathLegacyOpenIDConnectAuthorization = "/api/oidc/authorize"
- pathLegacyOpenIDConnectIntrospection = "/api/oidc/introspect"
- pathLegacyOpenIDConnectRevocation = "/api/oidc/revoke"
-
- // Note: If you change this const you must also do so in the frontend at web/src/services/Api.ts.
- pathOpenIDConnectConsent = "/api/oidc/consent"
-)
-
const (
accept = "accept"
reject = "reject"
diff --git a/internal/handlers/handler_oidc_jwks.go b/internal/handlers/handler_jwks.go
similarity index 67%
rename from internal/handlers/handler_oidc_jwks.go
rename to internal/handlers/handler_jwks.go
index 37e926345..14f680711 100644
--- a/internal/handlers/handler_oidc_jwks.go
+++ b/internal/handlers/handler_jwks.go
@@ -6,7 +6,8 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares"
)
-func oidcJWKs(ctx *middlewares.AutheliaCtx) {
+// JSONWebKeySetGET returns the JSON Web Key Set. Used in OAuth 2.0 and OpenID Connect 1.0.
+func JSONWebKeySetGET(ctx *middlewares.AutheliaCtx) {
ctx.SetContentType("application/json")
if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil {
diff --git a/internal/handlers/handler_oidc_introspection.go b/internal/handlers/handler_oauth_introspection.go
similarity index 80%
rename from internal/handlers/handler_oidc_introspection.go
rename to internal/handlers/handler_oauth_introspection.go
index ddc898103..331ce201d 100644
--- a/internal/handlers/handler_oidc_introspection.go
+++ b/internal/handlers/handler_oauth_introspection.go
@@ -9,7 +9,10 @@ import (
"github.com/authelia/authelia/v4/internal/oidc"
)
-func oidcIntrospection(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
+// OAuthIntrospectionPOST handles POST requests to the OAuth 2.0 Introspection endpoint.
+//
+// https://datatracker.ietf.org/doc/html/rfc7662
+func OAuthIntrospectionPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var (
responder fosite.IntrospectionResponder
err error
diff --git a/internal/handlers/handler_oidc_revocation.go b/internal/handlers/handler_oauth_revocation.go
similarity index 64%
rename from internal/handlers/handler_oidc_revocation.go
rename to internal/handlers/handler_oauth_revocation.go
index 84b4700cf..1dad867bc 100644
--- a/internal/handlers/handler_oidc_revocation.go
+++ b/internal/handlers/handler_oauth_revocation.go
@@ -8,7 +8,10 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares"
)
-func oidcRevocation(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
+// OAuthRevocationPOST handles POST requests to the OAuth 2.0 Revocation endpoint.
+//
+// https://datatracker.ietf.org/doc/html/rfc7009
+func OAuthRevocationPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var err error
if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil {
diff --git a/internal/handlers/handler_oidc_authorization.go b/internal/handlers/handler_oidc_authorization.go
index 88f407568..c5940410f 100644
--- a/internal/handlers/handler_oidc_authorization.go
+++ b/internal/handlers/handler_oidc_authorization.go
@@ -16,7 +16,10 @@ import (
"github.com/authelia/authelia/v4/internal/session"
)
-func oidcAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
+// OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint.
+//
+// https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
+func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
var (
requester fosite.AuthorizeRequester
responder fosite.AuthorizeResponder
diff --git a/internal/handlers/handler_oidc_consent.go b/internal/handlers/handler_oidc_consent.go
index f07b52e47..b403ecb01 100644
--- a/internal/handlers/handler_oidc_consent.go
+++ b/internal/handlers/handler_oidc_consent.go
@@ -7,7 +7,8 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares"
)
-func oidcConsent(ctx *middlewares.AutheliaCtx) {
+// OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect.
+func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession()
if userSession.OIDCWorkflowSession == nil {
@@ -39,7 +40,8 @@ func oidcConsent(ctx *middlewares.AutheliaCtx) {
}
}
-func oidcConsentPOST(ctx *middlewares.AutheliaCtx) {
+// OpenIDConnectConsentPOST handles consent responses for OpenID Connect.
+func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession()
if userSession.OIDCWorkflowSession == nil {
diff --git a/internal/handlers/handler_oidc_token.go b/internal/handlers/handler_oidc_token.go
index 714fcb555..59a9a55eb 100644
--- a/internal/handlers/handler_oidc_token.go
+++ b/internal/handlers/handler_oidc_token.go
@@ -9,7 +9,10 @@ import (
"github.com/authelia/authelia/v4/internal/oidc"
)
-func oidcToken(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
+// OpenIDConnectTokenPOST handles POST requests to the OpenID Connect 1.0 Token endpoint.
+//
+// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
+func OpenIDConnectTokenPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var (
requester fosite.AccessRequester
responder fosite.AccessResponder
diff --git a/internal/handlers/handler_oidc_userinfo.go b/internal/handlers/handler_oidc_userinfo.go
index 6cc2a90df..1a46ec39f 100644
--- a/internal/handlers/handler_oidc_userinfo.go
+++ b/internal/handlers/handler_oidc_userinfo.go
@@ -14,7 +14,10 @@ import (
"github.com/authelia/authelia/v4/internal/oidc"
)
-func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
+// OpenIDConnectUserinfo handles GET/POST requests to the OpenID Connect 1.0 UserInfo endpoint.
+//
+// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
+func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var (
tokenType fosite.TokenType
requester fosite.AccessRequester
@@ -97,7 +100,7 @@ func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *htt
var jti uuid.UUID
if jti, err = uuid.NewRandom(); err != nil {
- ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JWT ID."))
+ ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JTI."))
return
}
diff --git a/internal/handlers/handler_oidc_wellknown.go b/internal/handlers/handler_oidc_wellknown.go
index 3a5196c23..0efd5387c 100644
--- a/internal/handlers/handler_oidc_wellknown.go
+++ b/internal/handlers/handler_oidc_wellknown.go
@@ -8,7 +8,13 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares"
)
-func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
+// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
+// OpenID Connect Discovery 1.0 metadata.
+//
+// https://datatracker.ietf.org/doc/html/rfc5785
+//
+// https://openid.net/specs/openid-connect-discovery-1_0.html
+func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
@@ -30,7 +36,13 @@ func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
}
}
-func wellKnownOAuthAuthorizationServerGET(ctx *middlewares.AutheliaCtx) {
+// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
+// OAuth 2.0 Authorization Server Metadata (RFC8414).
+//
+// https://datatracker.ietf.org/doc/html/rfc5785
+//
+// https://datatracker.ietf.org/doc/html/rfc8414
+func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
diff --git a/internal/handlers/oidc_register.go b/internal/handlers/oidc_register.go
deleted file mode 100644
index 58646da7d..000000000
--- a/internal/handlers/oidc_register.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package handlers
-
-import (
- "github.com/fasthttp/router"
-
- "github.com/authelia/authelia/v4/internal/middlewares"
- "github.com/authelia/authelia/v4/internal/oidc"
-)
-
-// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout.
-func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) {
- // TODO: Add OPTIONS handler.
- router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET)))
- router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET)))
-
- router.GET(pathOpenIDConnectConsent, middleware(oidcConsent))
-
- router.POST(pathOpenIDConnectConsent, middleware(oidcConsentPOST))
-
- router.GET(oidc.JWKsPath, middleware(oidcJWKs))
-
- router.GET(oidc.AuthorizationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
- router.GET(pathLegacyOpenIDConnectAuthorization, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
-
- // TODO: Add OPTIONS handler.
- router.POST(oidc.TokenPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcToken)))
-
- router.POST(oidc.IntrospectionPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
- router.GET(pathLegacyOpenIDConnectIntrospection, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
-
- router.GET(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
- router.POST(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
-
- // TODO: Add OPTIONS handler.
- router.POST(oidc.RevocationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
- router.POST(pathLegacyOpenIDConnectRevocation, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
-}
diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go
index e83a25a36..b0475281d 100644
--- a/internal/middlewares/const.go
+++ b/internal/middlewares/const.go
@@ -7,18 +7,22 @@ import (
)
var (
+ headerAccept = []byte(fasthttp.HeaderAccept)
+ headerContentLength = []byte(fasthttp.HeaderContentLength)
+
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor)
headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith)
- headerAccept = []byte(fasthttp.HeaderAccept)
headerXForwardedURI = []byte("X-Forwarded-URI")
headerXOriginalURL = []byte("X-Original-URL")
headerXForwardedMethod = []byte("X-Forwarded-Method")
- headerVary = []byte(fasthttp.HeaderVary)
- headerOrigin = []byte(fasthttp.HeaderOrigin)
+ headerVary = []byte(fasthttp.HeaderVary)
+ headerAllow = []byte(fasthttp.HeaderAllow)
+ headerOrigin = []byte(fasthttp.HeaderOrigin)
+
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
@@ -29,9 +33,13 @@ var (
)
var (
- headerValueFalse = []byte("false")
- headerValueMaxAge = []byte("100")
- headerValueVary = []byte("Accept-Encoding, Origin")
+ headerValueFalse = []byte("false")
+ headerValueTrue = []byte("true")
+ headerValueMaxAge = []byte("100")
+ headerValueVary = []byte("Accept-Encoding, Origin")
+ headerValueVaryWildcard = []byte("Accept-Encoding")
+ headerValueOriginWildcard = []byte("*")
+ headerValueZero = []byte("0")
)
var (
@@ -40,6 +48,8 @@ var (
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
UserValueKeyBaseURL = []byte("base_url")
+
+ headerSeparator = []byte(", ")
)
const (
diff --git a/internal/middlewares/cors.go b/internal/middlewares/cors.go
index a4152dcf1..7936e6f70 100644
--- a/internal/middlewares/cors.go
+++ b/internal/middlewares/cors.go
@@ -1,53 +1,347 @@
package middlewares
import (
+ "bytes"
"net/url"
+ "strconv"
"strings"
"github.com/valyala/fasthttp"
+
+ "github.com/authelia/authelia/v4/internal/utils"
)
-// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well
-// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied
-// to both Accept-Encoding and Origin. It grants the GET Request Method only.
-func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler {
- return func(ctx *AutheliaCtx) {
- if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil {
- corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin)
+// NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy which adds the Vary header
+// with a value reflecting that the Origin header will Vary this response, then if the Origin header has a https scheme
+// it makes the following additional adjustments: copies the Origin header to the Access-Control-Allow-Origin header
+// effectively allowing all origins, sets the Access-Control-Allow-Credentials header to false which disallows CORS
+// requests from sending cookies etc, sets the Access-Control-Allow-Headers header to the value specified by
+// Access-Control-Request-Headers in the request excluding the Cookie/Authorization/Proxy-Authorization and special *
+// values, sets Access-Control-Allow-Methods to the value specified by the Access-Control-Request-Method header, sets
+// the Access-Control-Max-Age header to 100.
+//
+// These behaviours can be overridden by the With methods on the returned policy.
+func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) {
+ return &CORSPolicyBuilder{
+ enabled: true,
+ maxAge: 100,
+ }
+}
+
+// CORSPolicyBuilder is a special middleware which provides CORS headers via handlers and middleware methods which can be
+// configured. It aims to simplify CORS configurations.
+type CORSPolicyBuilder struct {
+ enabled bool
+ varyOnly bool
+ varySet bool
+ methods []string
+ headers []string
+ origins []string
+ credentials bool
+ vary []string
+ maxAge int
+}
+
+// Build reads the CORSPolicyBuilder configuration and generates a CORSPolicy.
+func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) {
+ policy = &CORSPolicy{
+ enabled: b.enabled,
+ varyOnly: b.varyOnly,
+ credentials: []byte(strconv.FormatBool(b.credentials)),
+ origins: b.buildOrigins(),
+ headers: b.buildHeaders(),
+ vary: b.buildVary(),
+ }
+
+ if len(b.methods) != 0 {
+ policy.methods = []byte(strings.Join(b.methods, ", "))
+ }
+
+ if b.maxAge <= 0 {
+ policy.maxAge = headerValueMaxAge
+ } else {
+ policy.maxAge = []byte(strconv.Itoa(b.maxAge))
+ }
+
+ return policy
+}
+
+func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) {
+ if len(b.origins) != 0 {
+ if len(b.origins) == 1 && b.origins[0] == "*" {
+ origins = append(origins, []byte(b.origins[0]))
+ } else {
+ for _, origin := range b.origins {
+ origins = append(origins, []byte(origin))
+ }
}
+ }
+
+ return origins
+}
+
+func (b CORSPolicyBuilder) buildHeaders() (headers []byte) {
+ if len(b.headers) != 0 {
+ h := b.headers
+
+ if b.credentials {
+ if !utils.IsStringInSliceFold(fasthttp.HeaderCookie, h) {
+ h = append(h, fasthttp.HeaderCookie)
+ }
+
+ if !utils.IsStringInSliceFold(fasthttp.HeaderAuthorization, h) {
+ h = append(h, fasthttp.HeaderAuthorization)
+ }
+
+ if !utils.IsStringInSliceFold(fasthttp.HeaderProxyAuthorization, h) {
+ h = append(h, fasthttp.HeaderProxyAuthorization)
+ }
+ }
+
+ headers = utils.JoinAndCanonicalizeHeaders(headerSeparator, h...)
+ }
+
+ return headers
+}
+
+func (b CORSPolicyBuilder) buildVary() (vary []byte) {
+ if b.varySet {
+ if len(b.vary) != 0 {
+ vary = utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...)
+ }
+ } else {
+ if len(b.origins) == 1 && b.origins[0] == "*" {
+ vary = headerValueVaryWildcard
+ } else {
+ vary = headerValueVary
+ }
+ }
+
+ return vary
+}
+
+// WithEnabled changes the enabled state of the middleware. If the middleware is initialized with NewCORSPolicyBuilder this
+// value will be true but this function can override the value. Setting it to false prevents the middleware from adding
+// any CORS headers. The only effect this middleware has after disabling this is the HandleOPTIONS and HandleOnlyOPTIONS
+// handlers still function to return a HTTP 204 No Content, with the Allow header communicating the available HTTP
+// method verbs. The main benefit of this option is that you don't have to implement complex logic to add/remove the
+// middleware, you can just add it with the Middleware method, and adjust it using the WithEnabled method.
+func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) {
+ b.enabled = enabled
+
+ return b
+}
+
+// WithAllowedMethods takes a list or HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with
+// that value.
+func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) {
+ b.methods = methods
+
+ return b
+}
+
+// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the origin matches one of these.
+func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) {
+ b.origins = origins
+
+ return b
+}
+
+// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header.
+func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) {
+ b.headers = headers
+
+ return b
+}
+
+// WithAllowCredentials takes bool and alters the default Access-Control-Allow-Credentials header.
+func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) {
+ b.credentials = allow
+
+ return b
+}
+
+// WithVary takes a list of header strings and alters the default Vary header.
+func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) {
+ b.vary = headers
+ b.varySet = true
+
+ return b
+}
+
+// WithVaryOnly just adds the Vary header.
+func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) {
+ b.varyOnly = varyOnly
+
+ return b
+}
+
+// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header.
+func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) {
+ b.maxAge = age
+
+ return b
+}
+
+// CORSPolicy is a middleware that handles adding CORS headers.
+type CORSPolicy struct {
+ enabled bool
+ varyOnly bool
+ methods []byte
+ headers []byte
+ origins [][]byte
+ credentials []byte
+ vary []byte
+ maxAge []byte
+}
+
+// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 204
+// without a body. This handler should generally not be used without using WithAllowedMethods.
+func (p CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) {
+ p.handleOPTIONS(ctx)
+ p.handle(ctx)
+}
+
+// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 204
+// without a body. This handler should generally not be used without using WithAllowedMethods.
+func (p CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) {
+ p.handleOPTIONS(ctx)
+}
+
+// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicyBuilder.
+func (p CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) {
+ return func(ctx *fasthttp.RequestCtx) {
+ p.handle(ctx)
next(ctx)
}
}
-func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) {
- originURL, err := url.Parse(string(origin))
- if err != nil || originURL.Scheme != "https" {
+func (p CORSPolicy) handle(ctx *fasthttp.RequestCtx) {
+ if !p.enabled {
return
}
- resp.Header.SetBytesKV(headerVary, headerValueVary)
- resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin)
- resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse)
- resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge)
+ p.handleVary(ctx)
- if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
- requestedHeaders := strings.Split(string(headers), ",")
- allowHeaders := make([]string, len(requestedHeaders))
+ if !p.varyOnly {
+ p.handleCORS(ctx)
+ }
+}
- for i, header := range requestedHeaders {
- headerTrimmed := strings.Trim(header, " ")
- if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) {
- allowHeaders[i] = headerTrimmed
+func (p CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) {
+ ctx.Response.ResetBody()
+
+ /* The OPTIONS method should not return a 204 as per the following specifications when read together:
+
+ RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7):
+ A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in
+ the response.
+
+ RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2):
+ A server MUST NOT send a Content-Length header field in any response with a status code of 1xx (Informational)
+ or 204 (No Content).
+ */
+ ctx.SetStatusCode(fasthttp.StatusOK)
+ ctx.Response.Header.SetBytesKV(headerContentLength, headerValueZero)
+
+ if len(p.methods) != 0 {
+ ctx.Response.Header.SetBytesKV(headerAllow, p.methods)
+ }
+}
+
+func (p CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) {
+ if len(p.vary) != 0 {
+ ctx.Response.Header.SetBytesKV(headerVary, p.vary)
+ }
+}
+
+func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
+ var (
+ originURL *url.URL
+ err error
+ )
+
+ origin := ctx.Request.Header.PeekBytes(headerOrigin)
+
+ // Skip processing of any `https` scheme URL that has not expressly been configured.
+ if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) {
+ return
+ }
+
+ var allowedOrigin []byte
+
+ switch len(p.origins) {
+ case 0:
+ allowedOrigin = origin
+ default:
+ for i := 0; i < len(p.origins); i++ {
+ if bytes.Equal(p.origins[i], headerValueOriginWildcard) {
+ allowedOrigin = headerValueOriginWildcard
+ } else if bytes.Equal(p.origins[i], origin) {
+ allowedOrigin = origin
}
}
- if len(allowHeaders) != 0 {
- resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
+ if len(allowedOrigin) == 0 {
+ return
}
}
- if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
- resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
+ ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin)
+
+ if len(p.credentials) != 0 {
+ ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials)
+ }
+
+ if len(p.maxAge) != 0 {
+ ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge)
+ }
+
+ p.handleAllowedHeaders(ctx)
+ p.handleAllowedMethods(ctx)
+}
+
+func (p CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) {
+ switch len(p.methods) {
+ case 0:
+ // TODO: It may be beneficial to be able to control this automatic behaviour.
+ if requestMethods := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
+ ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
+ }
+ default:
+ ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods)
+ }
+}
+
+func (p CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) {
+ switch len(p.headers) {
+ case 0:
+ // TODO: It may be beneficial to be able to control this automatic behaviour.
+ if headers := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
+ requestedHeaders := strings.Split(string(headers), ",")
+ allowHeaders := make([]string, 0, len(requestedHeaders))
+
+ for i := 0; i < len(requestedHeaders); i++ {
+ headerTrimmed := strings.Trim(requestedHeaders[i], " ")
+
+ if headerTrimmed == "*" {
+ continue
+ }
+
+ if bytes.Equal(p.credentials, headerValueTrue) ||
+ (!strings.EqualFold(fasthttp.HeaderCookie, headerTrimmed) &&
+ !strings.EqualFold(fasthttp.HeaderAuthorization, headerTrimmed) &&
+ !strings.EqualFold(fasthttp.HeaderProxyAuthorization, headerTrimmed)) {
+ allowHeaders = append(allowHeaders, headerTrimmed)
+ }
+ }
+
+ if len(allowHeaders) != 0 {
+ ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
+ }
+ }
+ default:
+ ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers)
}
}
diff --git a/internal/middlewares/cors_test.go b/internal/middlewares/cors_test.go
index f44106aee..9e42c103d 100644
--- a/internal/middlewares/cors_test.go
+++ b/internal/middlewares/cors_test.go
@@ -5,61 +5,587 @@ import (
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
+
+ "github.com/authelia/authelia/v4/internal/configuration/schema"
)
-func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
- req := fasthttp.AcquireRequest()
- resp := fasthttp.Response{}
+func TestNewCORSMiddleware(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Equal(t, 100, cors.maxAge)
+ assert.Equal(t, false, cors.credentials)
+
+ assert.Nil(t, cors.methods)
+ assert.Nil(t, cors.origins)
+ assert.Nil(t, cors.headers)
+ assert.Nil(t, cors.vary)
+ assert.False(t, cors.varyOnly)
+ assert.False(t, cors.varySet)
+}
+
+func TestCORSPolicyBuilder_WithEnabled(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.True(t, cors.enabled)
+
+ cors.WithEnabled(false)
+ assert.False(t, cors.enabled)
+}
+
+func TestCORSPolicyBuilder_WithVary(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Nil(t, cors.vary)
+ assert.False(t, cors.varyOnly)
+ assert.False(t, cors.varySet)
+
+ cors.WithVary()
+ assert.Nil(t, cors.vary)
+ assert.False(t, cors.varyOnly)
+ assert.True(t, cors.varySet)
+
+ cors.WithVary("Origin", "Example", "Test")
+
+ assert.Equal(t, []string{"Origin", "Example", "Test"}, cors.vary)
+ assert.False(t, cors.varyOnly)
+ assert.True(t, cors.varySet)
+}
+
+func TestCORSPolicyBuilder_WithAllowedMethods(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Nil(t, cors.methods)
+
+ cors.WithAllowedMethods("GET")
+
+ assert.Equal(t, []string{"GET"}, cors.methods)
+
+ cors.WithAllowedMethods("POST", "PATCH")
+
+ assert.Equal(t, []string{"POST", "PATCH"}, cors.methods)
+
+ cors.WithAllowedMethods()
+
+ assert.Nil(t, cors.methods)
+}
+
+func TestCORSPolicyBuilder_WithAllowedOrigins(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Nil(t, cors.origins)
+
+ cors.WithAllowedOrigins("https://google.com", "http://localhost")
+
+ assert.Equal(t, []string{"https://google.com", "http://localhost"}, cors.origins)
+
+ cors.WithAllowedOrigins()
+
+ assert.Nil(t, cors.origins)
+}
+
+func TestCORSPolicyBuilder_WithAllowedHeaders(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Nil(t, cors.headers)
+
+ cors.WithAllowedHeaders("Example", "Another")
+
+ assert.Equal(t, []string{"Example", "Another"}, cors.headers)
+
+ cors.WithAllowedHeaders()
+
+ assert.Nil(t, cors.headers)
+}
+
+func TestCORSPolicyBuilder_WithAllowCredentials(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Equal(t, false, cors.credentials)
+
+ cors.WithAllowCredentials(false)
+
+ assert.Equal(t, false, cors.credentials)
+
+ cors.WithAllowCredentials(true)
+
+ assert.Equal(t, true, cors.credentials)
+}
+
+func TestCORSPolicyBuilder_WithVaryOnly(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.False(t, cors.varyOnly)
+
+ cors.WithVaryOnly(false)
+
+ assert.False(t, cors.varyOnly)
+
+ cors.WithVaryOnly(true)
+
+ cors.WithVaryOnly(true)
+}
+
+func TestCORSPolicyBuilder_WithMaxAge(t *testing.T) {
+ cors := NewCORSPolicyBuilder()
+
+ assert.Equal(t, 100, cors.maxAge)
+
+ cors.WithMaxAge(20)
+
+ assert.Equal(t, 20, cors.maxAge)
+
+ cors.WithMaxAge(0)
+
+ assert.Equal(t, 0, cors.maxAge)
+}
+
+func TestCORSPolicyBuilder_HandleOPTIONS(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
- req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
- corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
+ cors := NewCORSPolicyBuilder()
+ policy := cors.Build()
- assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
- assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
- assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
- assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
- assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithAllowedMethods("GET", "OPTIONS")
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ policy = cors.Build()
+ policy.HandleOnlyOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithEnabled(false)
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func TestCORSPolicyBuilder_HandleOPTIONS_WithoutOrigin(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+
+ cors := NewCORSPolicyBuilder()
+
+ policy := cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+
+ cors.WithAllowedMethods("GET", "OPTIONS")
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedOrigins(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors := NewCORSPolicyBuilder()
+ cors.WithAllowedOrigins("https://myapp.example.com")
+
+ policy := cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithAllowedOrigins("https://anotherapp.example.com")
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithAllowedOrigins("*")
+ cors.WithAllowedMethods("GET", "OPTIONS")
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func TestCORSPolicyBuilder_WithAllowedOrigins_DoesntOverrideVary(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors := NewCORSPolicyBuilder()
+ cors.WithVary("Accept-Encoding", "Origin", "Test")
+ cors.WithAllowedOrigins("*")
+
+ policy := cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin, Test"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func TestCORSPolicyBuilder_HandleOPTIONSWithVaryOnly(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors := NewCORSPolicyBuilder()
+
+ cors.WithVaryOnly(true)
+
+ policy := cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithAllowedMethods("GET", "OPTIONS")
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedHeaders(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors := NewCORSPolicyBuilder()
+
+ cors.WithAllowedHeaders("Example", "Test")
+
+ policy := cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithAllowedMethods("GET", "OPTIONS")
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+
+ ctx = newFastHTTPRequestCtx()
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors.WithAllowCredentials(true)
+
+ policy = cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueTrue, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("Example, Test, Cookie, Authorization, Proxy-Authorization"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func TestCORSPolicyBuilder_HandleOPTIONS_ShouldNotAllowWildcardInRequestedHeaders(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "*")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+
+ cors := NewCORSPolicyBuilder()
+
+ policy := cors.Build()
+ policy.HandleOPTIONS(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+
+ cors := NewCORSPolicyBuilder()
+
+ policy := cors.Build()
+ policy.handle(ctx)
+
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
- req := fasthttp.AcquireRequest()
- resp := fasthttp.Response{}
+ ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
- req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
- req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
- corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
+ cors := NewCORSPolicyBuilder()
- assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
- assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
- assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
- assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
- assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
- assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods))
+ policy := cors.Build()
+ policy.handle(ctx)
+
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte("GET"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
- req := fasthttp.AcquireRequest()
-
- resp := fasthttp.Response{}
+ ctx := newFastHTTPRequestCtx()
origin := []byte("http://myapp.example.com")
- req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
- req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
- corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
+ cors := NewCORSPolicyBuilder().WithVary()
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary))
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin))
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials))
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge))
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
- assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
+ policy := cors.Build()
+ policy.handle(ctx)
+
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func Test_CORSMiddleware_AsMiddleware(t *testing.T) {
+ ctx := newFastHTTPRequestCtx()
+
+ origin := []byte("https://myapp.example.com")
+
+ ctx.Request.Header.SetBytesKV(headerOrigin, origin)
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
+ ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
+
+ autheliaMiddleware := AutheliaMiddleware(schema.Configuration{}, Providers{})
+
+ cors := NewCORSPolicyBuilder().WithAllowedMethods("GET", "OPTIONS")
+
+ policy := cors.Build()
+
+ route := policy.Middleware(autheliaMiddleware(testNilHandler))
+
+ route(ctx)
+
+ assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
+ assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
+ assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
+ assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
+ assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
+ assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
+ assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
+}
+
+func testNilHandler(_ *AutheliaCtx) {}
+
+func newFastHTTPRequestCtx() (ctx *fasthttp.RequestCtx) {
+ return &fasthttp.RequestCtx{
+ Request: fasthttp.Request{},
+ Response: fasthttp.Response{},
+ }
}
diff --git a/internal/oidc/const.go b/internal/oidc/const.go
index 2d128aec5..72d395dd1 100644
--- a/internal/oidc/const.go
+++ b/internal/oidc/const.go
@@ -19,17 +19,28 @@ const (
ClaimEmailAlts = "alt_emails"
)
+// Endpoints.
+const (
+ AuthorizationEndpoint = "authorization"
+ TokenEndpoint = "token"
+ UserinfoEndpoint = "userinfo"
+ IntrospectionEndpoint = "introspection"
+ RevocationEndpoint = "revocation"
+)
+
// Paths.
const (
WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration"
WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server"
+ JWKsPath = "/jwks.json"
- JWKsPath = "/api/oidc/jwks"
- AuthorizationPath = "/api/oidc/authorization"
- TokenPath = "/api/oidc/token" //nolint:gosec // This is not a hard coded credential, it's a path.
- IntrospectionPath = "/api/oidc/introspection"
- RevocationPath = "/api/oidc/revocation"
- UserinfoPath = "/api/oidc/userinfo"
+ RootPath = "/api/oidc"
+
+ AuthorizationPath = RootPath + "/" + AuthorizationEndpoint
+ TokenPath = RootPath + "/" + TokenEndpoint
+ UserinfoPath = RootPath + "/" + UserinfoEndpoint
+ IntrospectionPath = RootPath + "/" + IntrospectionEndpoint
+ RevocationPath = RootPath + "/" + RevocationEndpoint
)
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176
diff --git a/internal/oidc/provider_test.go b/internal/oidc/provider_test.go
index c8bfba7d1..a6f6e0cb7 100644
--- a/internal/oidc/provider_test.go
+++ b/internal/oidc/provider_test.go
@@ -87,7 +87,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
assert.Equal(t, "https://example.com", disco.Issuer)
- assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI)
+ assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint)
@@ -173,7 +173,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig
disco := provider.GetOAuth2WellKnownConfiguration("https://example.com")
assert.Equal(t, "https://example.com", disco.Issuer)
- assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI)
+ assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint)
diff --git a/internal/server/const.go b/internal/server/const.go
index a00d6b880..098a5bea8 100644
--- a/internal/server/const.go
+++ b/internal/server/const.go
@@ -37,16 +37,17 @@ var (
{name: "/api", prefix: "/api/"},
{name: "/.well-known", prefix: "/.well-known/"},
{name: "/static", prefix: "/static/"},
+ {name: "/locales", prefix: "/locales/"},
}
)
-const schemeHTTP = "http"
-const schemeHTTPS = "https"
-
const (
- dev = "dev"
- f = "false"
- t = "true"
+ dev = "dev"
+ f = "false"
+ t = "true"
+ localhost = "localhost"
+ schemeHTTP = "http"
+ schemeHTTPS = "https"
)
const healthCheckEnv = `# Written by Authelia Process
diff --git a/internal/server/error_handler.go b/internal/server/error_handler.go
deleted file mode 100644
index 36b93e5b2..000000000
--- a/internal/server/error_handler.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package server
-
-import (
- "net"
-
- "github.com/valyala/fasthttp"
-
- "github.com/authelia/authelia/v4/internal/logging"
-)
-
-// Replacement for the default error handler in fasthttp.
-func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) {
- logger := logging.Logger()
-
- if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
- // Note: Getting X-Forwarded-For or Request URI is impossible for ths error.
- logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
- ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
- } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
- // TODO: Add X-Forwarded-For Check here.
- logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
- ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
- } else {
- // TODO: Add X-Forwarded-For Check here.
- logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
- ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
- }
-}
diff --git a/internal/server/handler_notfound.go b/internal/server/handler_notfound.go
deleted file mode 100644
index c8002daa4..000000000
--- a/internal/server/handler_notfound.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package server
-
-import (
- "strings"
-
- "github.com/valyala/fasthttp"
-
- "github.com/authelia/authelia/v4/internal/handlers"
-)
-
-func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
- return func(ctx *fasthttp.RequestCtx) {
- path := strings.ToLower(string(ctx.Path()))
-
- for i := 0; i < len(httpServerDirs); i++ {
- if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
- handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
-
- return
- }
- }
-
- next(ctx)
- }
-}
diff --git a/internal/server/handlers.go b/internal/server/handlers.go
new file mode 100644
index 000000000..1beb151ac
--- /dev/null
+++ b/internal/server/handlers.go
@@ -0,0 +1,56 @@
+package server
+
+import (
+ "net"
+ "strings"
+
+ "github.com/valyala/fasthttp"
+
+ "github.com/authelia/authelia/v4/internal/handlers"
+ "github.com/authelia/authelia/v4/internal/logging"
+)
+
+// Replacement for the default error handler in fasthttp.
+func handlerErrors(ctx *fasthttp.RequestCtx, err error) {
+ logger := logging.Logger()
+
+ switch e := err.(type) {
+ case *fasthttp.ErrSmallBuffer:
+ logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
+ ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
+ case *net.OpError:
+ if e.Timeout() {
+ // TODO: Add X-Forwarded-For Check here.
+ logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
+ ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
+ } else {
+ // TODO: Add X-Forwarded-For Check here.
+ logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
+ ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
+ }
+ default:
+ // TODO: Add X-Forwarded-For Check here.
+ logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
+ ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
+ }
+}
+
+func handlerNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
+ return func(ctx *fasthttp.RequestCtx) {
+ path := strings.ToLower(string(ctx.Path()))
+
+ for i := 0; i < len(httpServerDirs); i++ {
+ if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
+ handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
+
+ return
+ }
+ }
+
+ next(ctx)
+ }
+}
+
+func handlerMethodNotAllowed(ctx *fasthttp.RequestCtx) {
+ handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
+}
diff --git a/internal/server/options_handler.go b/internal/server/options_handler.go
deleted file mode 100644
index 7ebd15d3c..000000000
--- a/internal/server/options_handler.go
+++ /dev/null
@@ -1,11 +0,0 @@
-package server
-
-import (
- "github.com/valyala/fasthttp"
-
- "github.com/authelia/authelia/v4/internal/middlewares"
-)
-
-func handleOPTIONS(ctx *middlewares.AutheliaCtx) {
- ctx.SetStatusCode(fasthttp.StatusNoContent)
-}
diff --git a/internal/server/server.go b/internal/server/server.go
index 9e40679af..985fd5b20 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -19,10 +19,12 @@ import (
"github.com/authelia/authelia/v4/internal/handlers"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/oidc"
+ "github.com/authelia/authelia/v4/internal/utils"
)
+// TODO: move to its own file and rename configuration -> config.
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
- autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled)
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
@@ -33,37 +35,49 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment)
}
- handlerPublicHTML := newPublicHTMLEmbeddedHandler()
- handlerLocales := newLocalesEmbeddedHandler()
-
https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != ""
serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
+ handlerPublicHTML := newPublicHTMLEmbeddedHandler()
+ handlerLocales := newLocalesEmbeddedHandler()
+
+ autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
+
+ policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
+ WithAllowedMethods("OPTIONS", "GET").
+ WithAllowedOrigins("*").
+ Build()
+
r := router.New()
+
+ // Static Assets.
r.GET("/", autheliaMiddleware(serveIndexHandler))
- r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
for _, f := range rootFiles {
r.GET("/"+f, handlerPublicHTML)
}
- r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
- r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler))
-
- for _, file := range swaggerFiles {
- r.GET("/api/"+file, handlerPublicHTML)
- }
-
r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML))
r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML))
r.GET("/static/{filepath:*}", handlerPublicHTML)
+ // Locales.
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
+ // Swagger.
+ r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
+ r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)
+ r.GET("/api/"+apiFile, policyCORSPublicGET.Middleware(autheliaMiddleware(serveSwaggerAPIHandler)))
+ r.OPTIONS("/api/"+apiFile, policyCORSPublicGET.HandleOPTIONS)
+
+ for _, file := range swaggerFiles {
+ r.GET("/api/"+file, handlerPublicHTML)
+ }
+
r.GET("/api/health", autheliaMiddleware(handlers.HealthGet))
r.GET("/api/state", autheliaMiddleware(handlers.StateGet))
@@ -161,22 +175,98 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
r.GET("/debug/vars", expvarhandler.ExpvarHandler)
}
- r.NotFound = handleNotFound(autheliaMiddleware(serveIndexHandler))
+ if providers.OpenIDConnect.Fosite != nil {
+ r.GET("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentGET))
+ r.POST("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentPOST))
+
+ allowedOrigins := utils.StringSliceFromURLs(configuration.IdentityProviders.OIDC.CORS.AllowedOrigins)
+
+ r.OPTIONS(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.HandleOPTIONS)
+ r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OpenIDConnectConfigurationWellKnownGET)))
+
+ r.OPTIONS(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.HandleOPTIONS)
+ r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OAuthAuthorizationServerWellKnownGET)))
+
+ r.OPTIONS(oidc.JWKsPath, policyCORSPublicGET.HandleOPTIONS)
+ r.GET(oidc.JWKsPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
+
+ // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
+ r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS)
+ r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
+
+ policyCORSAuthorization := middlewares.NewCORSPolicyBuilder().
+ WithAllowedMethods("OPTIONS", "GET").
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.AuthorizationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
+
+ r.OPTIONS(oidc.AuthorizationPath, policyCORSAuthorization.HandleOnlyOPTIONS)
+ r.GET(oidc.AuthorizationPath, autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
+
+ // TODO (james-d-elliott): Remove in GA. This is a legacy endpoint.
+ r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS)
+ r.GET("/api/oidc/authorize", autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
+
+ policyCORSToken := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods("OPTIONS", "POST").
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.TokenEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
+
+ r.OPTIONS(oidc.TokenPath, policyCORSToken.HandleOPTIONS)
+ r.POST(oidc.TokenPath, policyCORSToken.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST))))
+
+ policyCORSUserinfo := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods("OPTIONS", "GET", "POST").
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.UserinfoEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
+
+ r.OPTIONS(oidc.UserinfoPath, policyCORSUserinfo.HandleOPTIONS)
+ r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
+ r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
+
+ policyCORSIntrospection := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods("OPTIONS", "POST").
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.IntrospectionEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
+
+ r.OPTIONS(oidc.IntrospectionPath, policyCORSIntrospection.HandleOPTIONS)
+ r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
+
+ // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
+ r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS)
+ r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
+
+ policyCORSRevocation := middlewares.NewCORSPolicyBuilder().
+ WithAllowCredentials(true).
+ WithAllowedMethods("OPTIONS", "POST").
+ WithAllowedOrigins(allowedOrigins...).
+ WithEnabled(utils.IsStringInSlice(oidc.RevocationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
+ Build()
+
+ r.OPTIONS(oidc.RevocationPath, policyCORSRevocation.HandleOPTIONS)
+ r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
+
+ // TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
+ r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS)
+ r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
+ }
+
+ r.NotFound = handlerNotFound(autheliaMiddleware(serveIndexHandler))
r.HandleMethodNotAllowed = true
- r.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) {
- handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
- }
+ r.MethodNotAllowed = handlerMethodNotAllowed
handler := middlewares.LogRequestMiddleware(r.Handler)
if configuration.Server.Path != "" {
handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler)
}
- if providers.OpenIDConnect.Fosite != nil {
- handlers.RegisterOIDC(r, autheliaMiddleware)
- }
-
return handler
}
@@ -185,12 +275,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
handler := registerRoutes(configuration, providers)
server := &fasthttp.Server{
- ErrorHandler: autheliaErrorHandler,
+ ErrorHandler: handlerErrors,
Handler: handler,
NoDefaultServerHeader: true,
ReadBufferSize: configuration.Server.ReadBufferSize,
WriteBufferSize: configuration.Server.WriteBufferSize,
}
+
logger := logging.Logger()
address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port))
@@ -204,9 +295,8 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" {
connectionType, connectionScheme = "TLS", schemeHTTPS
- err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)
- if err != nil {
+ if err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key); err != nil {
logger.Fatalf("unable to load certificate: %v", err)
}
@@ -228,14 +318,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
- listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone())
- if err != nil {
+ if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
logger.Fatalf("Error initializing listener: %s", err)
}
} else {
connectionType, connectionScheme = "non-TLS", schemeHTTP
- listener, err = net.Listen("tcp", address)
- if err != nil {
+
+ if listener, err = net.Listen("tcp", address); err != nil {
logger.Fatalf("Error initializing listener: %s", err)
}
}
@@ -245,11 +334,10 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
logger.Fatalf("Could not configure healthcheck: %v", err)
}
- actualAddress := listener.Addr().String()
if configuration.Server.Path == "" {
- logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress)
+ logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, listener.Addr().String())
} else {
- logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path)
+ logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, listener.Addr().String(), configuration.Server.Path)
}
return server, listener
diff --git a/internal/server/template.go b/internal/server/template.go
index b71041bc6..31b812500 100644
--- a/internal/server/template.go
+++ b/internal/server/template.go
@@ -48,14 +48,14 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
}
}
- var scheme = "https"
+ var scheme = schemeHTTPS
if !https {
proto := string(ctx.XForwardedProto())
switch proto {
case "":
break
- case "http", "https":
+ case schemeHTTP, schemeHTTPS:
scheme = proto
}
}
@@ -116,7 +116,7 @@ func writeHealthCheckEnv(disabled bool, scheme, host, path string, port int) (er
}()
if host == "0.0.0.0" {
- host = "localhost"
+ host = localhost
} else if strings.Contains(host, ":") {
host = "[" + host + "]"
}
diff --git a/internal/utils/strings.go b/internal/utils/strings.go
index 19e049200..231621f27 100644
--- a/internal/utils/strings.go
+++ b/internal/utils/strings.go
@@ -8,6 +8,8 @@ import (
"strings"
"time"
"unicode"
+
+ "github.com/valyala/fasthttp"
)
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
@@ -145,6 +147,52 @@ func IsStringSlicesDifferentFold(a, b []string) (different bool) {
return isStringSlicesDifferent(a, b, IsStringInSliceFold)
}
+// IsURLInSlice returns true if the needle url.URL is in the []url.URL haystack.
+func IsURLInSlice(needle url.URL, haystack []url.URL) (has bool) {
+ for i := 0; i < len(haystack); i++ {
+ if strings.EqualFold(needle.String(), haystack[i].String()) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// StringSliceFromURLs returns a []string from a []url.URL.
+func StringSliceFromURLs(urls []url.URL) []string {
+ result := make([]string, len(urls))
+
+ for i := 0; i < len(urls); i++ {
+ result[i] = urls[i].String()
+ }
+
+ return result
+}
+
+// URLsFromStringSlice returns a []url.URL from a []string.
+func URLsFromStringSlice(urls []string) []url.URL {
+ var result []url.URL
+
+ for i := 0; i < len(urls); i++ {
+ u, err := url.Parse(urls[i])
+ if err != nil {
+ continue
+ }
+
+ result = append(result, *u)
+ }
+
+ return result
+}
+
+// OriginFromURL returns an origin url.URL given another url.URL.
+func OriginFromURL(u url.URL) (origin url.URL) {
+ return url.URL{
+ Scheme: u.Scheme,
+ Host: u.Host,
+ }
+}
+
// StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string.
func StringSlicesDelta(before, after []string) (added, removed []string) {
for _, s := range before {
@@ -193,6 +241,19 @@ func StringHTMLEscape(input string) (output string) {
return htmlEscaper.Replace(input)
}
+// JoinAndCanonicalizeHeaders join header strings by a given sep.
+func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) {
+ for i, header := range headers {
+ if i != 0 {
+ joined = append(joined, sep...)
+ }
+
+ joined = fasthttp.AppendNormalizedHeaderKey(joined, header)
+ }
+
+ return joined
+}
+
func init() {
rand.Seed(time.Now().UnixNano())
}
diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go
index c006f57e3..2c9ad981e 100644
--- a/internal/utils/strings_test.go
+++ b/internal/utils/strings_test.go
@@ -1,6 +1,7 @@
package utils
import (
+ "net/url"
"testing"
"github.com/stretchr/testify/assert"
@@ -171,3 +172,48 @@ func TestIsStringSliceContainsAny(t *testing.T) {
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
}
+
+func TestStringSliceURLConversionFuncs(t *testing.T) {
+ urls := URLsFromStringSlice([]string{"https://google.com", "abc", "%*()@#$J(@*#$J@#($H"})
+
+ require.Len(t, urls, 2)
+ assert.Equal(t, "https://google.com", urls[0].String())
+ assert.Equal(t, "abc", urls[1].String())
+
+ strs := StringSliceFromURLs(urls)
+
+ require.Len(t, strs, 2)
+ assert.Equal(t, "https://google.com", strs[0])
+ assert.Equal(t, "abc", strs[1])
+}
+
+func TestIsURLInSlice(t *testing.T) {
+ urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
+
+ google, err := url.Parse("https://google.com")
+ assert.NoError(t, err)
+
+ microsoft, err := url.Parse("https://microsoft.com")
+ assert.NoError(t, err)
+
+ example, err := url.Parse("https://example.com")
+ assert.NoError(t, err)
+
+ assert.True(t, IsURLInSlice(*google, urls))
+ assert.False(t, IsURLInSlice(*microsoft, urls))
+ assert.True(t, IsURLInSlice(*example, urls))
+}
+
+func TestOriginFromURL(t *testing.T) {
+ google, err := url.Parse("https://google.com/abc?a=123#five")
+ assert.NoError(t, err)
+
+ origin := OriginFromURL(*google)
+ assert.Equal(t, "https://google.com", origin.String())
+}
+
+func TestJoinAndCanonicalizeHeaders(t *testing.T) {
+ result := JoinAndCanonicalizeHeaders([]byte(", "), "x-example-ONE", "X-EGG-Two")
+
+ assert.Equal(t, []byte("X-Example-One, X-Egg-Two"), result)
+}