feat(oidc): provide cors config including options handlers (#3005)

This adjusts the CORS headers appropriately for OpenID Connect. This includes responding to OPTIONS requests appropriately. Currently this is only configured to operate when the Origin scheme is HTTPS; but can easily be expanded in the future to include additional Origins.
pull/3125/head
James Elliott 2022-04-07 10:58:51 +10:00 committed by GitHub
parent a694cf851f
commit 4ebd8fdf4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1729 additions and 254 deletions

View File

@ -767,6 +767,26 @@ notifier:
## for security reasons. ## for security reasons.
# enforce_pkce: public_clients_only # enforce_pkce: public_clients_only
## Cross-Origin Resource Sharing (CORS) settings.
# cors:
## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
# endpoints:
# - authorization
# - token
# - revocation
# - introspection
# - userinfo
## List of allowed origins.
## Any origin with https is permitted unless this option is configured or the
## allowed_origins_from_client_redirect_uris option is enabled.
# allowed_origins:
# - https://example.com
## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
## provided they have the scheme http or https and do not have the hostname of localhost.
# allowed_origins_from_client_redirect_uris: false
## Clients is a list of known clients and their configuration. ## Clients is a list of known clients and their configuration.
# clients: # clients:
# - # -

View File

@ -35,6 +35,15 @@ identity_providers:
refresh_token_lifespan: 90m refresh_token_lifespan: 90m
enable_client_debug_messages: false enable_client_debug_messages: false
enforce_pkce: public_clients_only enforce_pkce: public_clients_only
cors:
endpoints:
- authorization
- token
- revocation
- introspection
allowed_origins:
- https://example.com
allowed_origins_from_client_redirect_uris: false
clients: clients:
- id: myapp - id: myapp
description: My Application description: My Application
@ -218,6 +227,79 @@ Allows PKCE `plain` challenges when set to `true`.
***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead. ***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead.
### cors
Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
you to configure the optional parts. We reply with CORS headers when the request includes the Origin header.
##### endpoints
<div markdown="1">
type: list(string)
{: .label .label-config .label-purple }
default: empty
{: .label .label-config .label-blue }
required: no
{: .label .label-config .label-green }
</div>
A list of endpoints to configure with cross-origin resource sharing headers. It is recommended that the `userinfo`
option is at least in this list. The potential endpoints which this can be enabled on are as follows:
* authorization
* token
* revocation
* introspection
* userinfo
#### allowed_origins
<div markdown="1">
type: list(string)
{: .label .label-config .label-purple }
default: empty
{: .label .label-config .label-blue }
required: no
{: .label .label-config .label-green }
</div>
A list of permitted origins.
Any origin with https is permitted unless this option is configured or the allowed_origins_from_client_redirect_uris
option is enabled. This means you must configure this option manually if you want http endpoints to be permitted to
make cross-origin requests to the OpenID Connect endpoints, however this is not recommended.
Origins must only have the scheme, hostname and port, they may not have a trailing slash or path.
In addition to an Origin URI, you may specify the wildcard origin in the allowed_origins. It MUST be specified by itself
and the allowed_origins_from_client_redirect_uris MUST NOT be enabled. The wildcard origin is denoted as `*`. Examples:
```yaml
identity_providers:
oidc:
cors:
allowed_origins: "*"
```
```yaml
identity_providers:
oidc:
cors:
allowed_origins:
- "*"
```
#### allowed_origins_from_client_redirect_uris
<div markdown="1">
type: boolean
{: .label .label-config .label-purple }
default: false
{: .label .label-config .label-blue }
required: no
{: .label .label-config .label-green }
</div>
Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, provided they
have the scheme http or https and do not have the hostname of localhost.
### clients ### clients
A list of clients to configure. The options for each client are described below. A list of clients to configure. The options for each client are described below.
@ -487,22 +569,46 @@ Below is a list of the potential values we place in the claim and their meaning:
## Endpoint Implementations ## Endpoint Implementations
This is a table of the endpoints we currently support and their paths. This can be requrired information for some RP's, The following section documents the endpoints we implement and their respective paths. This information can traditionally
particularly those that don't use [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html). The paths are be discovered by relying parties that utilize [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html),
appended to the end of the primary URL used to access Authelia. For example in the Discovery example provided you access however this information may be useful for clients which do not implement this.
Authelia via https://auth.example.com, the discovery URL is https://auth.example.com/.well-known/openid-configuration.
| Endpoint | Path | The endpoints can be discovered easily by visiting the Discovery and Metadata endpoints. It is recommended regardless
|:-------------:|:---------------------------------------------:| of your version of Authelia that you utilize this version as it will always produce the correct endpoint URLs. The paths
| Discovery | [root]/.well-known/openid-configuration | for the Discovery/Metadata endpoints are part of IANA's well known registration but are also documented in a table below.
| Metadata | [root]/.well-known/oauth-authorization-server |
| JWKS | [root]/api/oidc/jwks | These tables document the endpoints we currently support and their paths in the most recent version of Authelia. The paths
| Authorization | [root]/api/oidc/authorization | are appended to the end of the primary URL used to access Authelia. The tables use the url https://auth.example.com as
| Token | [root]/api/oidc/token | an example of the Authelia root URL which is also the OpenID Connect issuer.
| Introspection | [root]/api/oidc/introspection |
| Revocation | [root]/api/oidc/revocation | ### Well Known Discovery Endpoints
| Userinfo | [root]/api/oidc/userinfo |
These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP.
| Endpoint | Path |
|:-------------:|:---------------------------------------------------------------:|
| Discovery | https://auth.example.com/.well-known/openid-configuration |
| Metadata | https://auth.example.com/.well-known/oauth-authorization-server |
### Discoverable Endpoints
These endpoints implement OpenID Connect elements.
| Endpoint | Path | Discovery Attribute |
|:---------------:|:-----------------------------------------------:|:----------------------:|
| JWKS | https://auth.example.com/jwks.json | jwks_uri |
| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
[OpenID Connect]: https://openid.net/connect/ [OpenID Connect]: https://openid.net/connect/
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration [token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176 [RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176

View File

@ -767,6 +767,26 @@ notifier:
## for security reasons. ## for security reasons.
# enforce_pkce: public_clients_only # enforce_pkce: public_clients_only
## Cross-Origin Resource Sharing (CORS) settings.
# cors:
## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
# endpoints:
# - authorization
# - token
# - revocation
# - introspection
# - userinfo
## List of allowed origins.
## Any origin with https is permitted unless this option is configured or the
## allowed_origins_from_client_redirect_uris option is enabled.
# allowed_origins:
# - https://example.com
## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
## provided they have the scheme http or https and do not have the hostname of localhost.
# allowed_origins_from_client_redirect_uris: false
## Clients is a list of known clients and their configuration. ## Clients is a list of known clients and their configuration.
# clients: # clients:
# - # -

View File

@ -201,6 +201,24 @@ func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) {
assert.Equal(t, "example_secret value", config.Storage.EncryptionKey) assert.Equal(t, "example_secret value", config.Storage.EncryptionKey)
} }
func TestShouldLoadURLList(t *testing.T) {
testReset()
val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
assert.NoError(t, err)
validator.ValidateKeys(keys, DefaultEnvPrefix, val)
assert.Len(t, val.Errors(), 0)
assert.Len(t, val.Warnings(), 0)
require.Len(t, config.IdentityProviders.OIDC.CORS.AllowedOrigins, 2)
assert.Equal(t, "https://google.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[0].String())
assert.Equal(t, "https://example.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[1].String())
}
func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) { func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
testReset() testReset()

View File

@ -1,6 +1,9 @@
package schema package schema
import "time" import (
"net/url"
"time"
)
// IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia. // IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia.
type IdentityProvidersConfiguration struct { type IdentityProvidersConfiguration struct {
@ -24,9 +27,19 @@ type OpenIDConnectConfiguration struct {
EnforcePKCE string `koanf:"enforce_pkce"` EnforcePKCE string `koanf:"enforce_pkce"`
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"` EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
CORS OpenIDConnectCORSConfiguration `koanf:"cors"`
Clients []OpenIDConnectClientConfiguration `koanf:"clients"` Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
} }
// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config.
type OpenIDConnectCORSConfiguration struct {
Endpoints []string `koanf:"endpoints"`
AllowedOrigins []url.URL `koanf:"allowed_origins"`
AllowedOriginsFromClientRedirectURIs bool `koanf:"allowed_origins_from_client_redirect_uris"`
}
// OpenIDConnectClientConfiguration configuration for an OpenID Connect client. // OpenIDConnectClientConfiguration configuration for an OpenID Connect client.
type OpenIDConnectClientConfiguration struct { type OpenIDConnectClientConfiguration struct {
ID string `koanf:"id"` ID string `koanf:"id"`

View File

@ -0,0 +1,133 @@
---
default_redirection_url: https://home.example.com:8080/
server:
host: 127.0.0.1
port: 9091
log:
level: debug
totp:
issuer: authelia.com
duo_api:
hostname: api-123456789.example.com
integration_key: ABCDEF
authentication_backend:
ldap:
url: ldap://127.0.0.1
base_dn: dc=example,dc=com
username_attribute: uid
additional_users_dn: ou=users
users_filter: (&({username_attribute}={input})(objectCategory=person)(objectClass=user))
additional_groups_dn: ou=groups
groups_filter: (&(member={dn})(objectClass=groupOfNames))
group_name_attribute: cn
mail_attribute: mail
user: cn=admin,dc=example,dc=com
access_control:
default_policy: deny
rules:
# Rules applied to everyone
- domain: public.example.com
policy: bypass
- domain: secure.example.com
policy: one_factor
# Network based rule, if not provided any network matches.
networks:
- 192.168.1.0/24
- domain: secure.example.com
policy: two_factor
- domain: [singlefactor.example.com, onefactor.example.com]
policy: one_factor
# Rules applied to 'admins' group
- domain: "mx2.mail.example.com"
subject: "group:admins"
policy: deny
- domain: "*.example.com"
subject: "group:admins"
policy: two_factor
# Rules applied to 'dev' group
- domain: dev.example.com
resources:
- "^/groups/dev/.*$"
subject: "group:dev"
policy: two_factor
# Rules applied to user 'john'
- domain: dev.example.com
resources:
- "^/users/john/.*$"
subject: "user:john"
policy: two_factor
# Rules applied to 'dev' group and user 'john'
- domain: dev.example.com
resources:
- "^/deny-all.*$"
subject: ["group:dev", "user:john"]
policy: deny
# Rules applied to user 'harry'
- domain: dev.example.com
resources:
- "^/users/harry/.*$"
subject: "user:harry"
policy: two_factor
# Rules applied to user 'bob'
- domain: "*.mail.example.com"
subject: "user:bob"
policy: two_factor
- domain: "dev.example.com"
resources:
- "^/users/bob/.*$"
subject: "user:bob"
policy: two_factor
session:
name: authelia_session
expiration: 3600000 # 1 hour
inactivity: 300000 # 5 minutes
domain: example.com
redis:
host: 127.0.0.1
port: 6379
high_availability:
sentinel_name: test
regulation:
max_retries: 3
find_time: 120
ban_time: 300
storage:
mysql:
host: 127.0.0.1
port: 3306
database: authelia
username: authelia
notifier:
smtp:
username: test
host: 127.0.0.1
port: 1025
sender: admin@example.com
disable_require_tls: true
identity_providers:
oidc:
cors:
allowed_origins:
- https://google.com
- https://example.com
...

View File

@ -123,11 +123,15 @@ const (
const ( const (
errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " + errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " +
"more clients configured" "more clients configured"
errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required" errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " + errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " +
"'public_clients_only' or 'always', but it is configured as '%s'" "'public_clients_only' or 'always', but it is configured as '%s'"
errFmtOIDCCORSInvalidOrigin = "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value '%s' as it has a %s: origins must only be scheme, hostname, and an optional port"
errFmtOIDCCORSInvalidOriginWildcard = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself"
errFmtOIDCCORSInvalidOriginWildcardWithClients = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled"
errFmtOIDCCORSInvalidEndpoint = "identity_providers: oidc: cors: option 'endpoints' contains an invalid value '%s': must be one of '%s'"
errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" + errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" +
"id's must be unique" "id's must be unique"
errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " + errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " +
@ -275,6 +279,7 @@ var validOIDCScopes = []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProf
var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"} var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}
var validOIDCResponseModes = []string{"form_post", "query", "fragment"} var validOIDCResponseModes = []string{"form_post", "query", "fragment"}
var validOIDCUserinfoAlgorithms = []string{"none", "RS256"} var validOIDCUserinfoAlgorithms = []string{"none", "RS256"}
var validOIDCCORSEndpoints = []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint}
var reKeyReplacer = regexp.MustCompile(`\[\d+]`) var reKeyReplacer = regexp.MustCompile(`\[\d+]`)
@ -471,6 +476,9 @@ var ValidKeys = []string{
"identity_providers.oidc.enable_pkce_plain_challenge", "identity_providers.oidc.enable_pkce_plain_challenge",
"identity_providers.oidc.enable_client_debug_messages", "identity_providers.oidc.enable_client_debug_messages",
"identity_providers.oidc.minimum_parameter_entropy", "identity_providers.oidc.minimum_parameter_entropy",
"identity_providers.oidc.cors.endpoints",
"identity_providers.oidc.cors.allowed_origins",
"identity_providers.oidc.cors.enable_origins_from_clients",
"identity_providers.oidc.clients", "identity_providers.oidc.clients",
"identity_providers.oidc.clients[].id", "identity_providers.oidc.clients[].id",
"identity_providers.oidc.clients[].description", "identity_providers.oidc.clients[].description",

View File

@ -49,6 +49,7 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE)) validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE))
} }
validateOIDCOptionsCORS(config, validator)
validateOIDCClients(config, validator) validateOIDCClients(config, validator)
if len(config.Clients) == 0 { if len(config.Clients) == 0 {
@ -57,6 +58,64 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
} }
} }
func validateOIDCOptionsCORS(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
validateOIDCOptionsCORSAllowedOrigins(config, validator)
if config.CORS.AllowedOriginsFromClientRedirectURIs {
validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config)
}
validateOIDCOptionsCORSEndpoints(config, validator)
}
func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
for _, origin := range config.CORS.AllowedOrigins {
if origin.String() == "*" {
if len(config.CORS.AllowedOrigins) != 1 {
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard))
}
if config.CORS.AllowedOriginsFromClientRedirectURIs {
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients))
}
continue
}
if origin.Path != "" {
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path"))
}
if origin.RawQuery != "" {
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string"))
}
}
}
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
for _, client := range config.Clients {
for _, redirectURI := range client.RedirectURIs {
uri, err := url.Parse(redirectURI)
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
continue
}
origin := utils.OriginFromURL(*uri)
if !utils.IsURLInSlice(origin, config.CORS.AllowedOrigins) {
config.CORS.AllowedOrigins = append(config.CORS.AllowedOrigins, origin)
}
}
}
}
func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
for _, endpoint := range config.CORS.Endpoints {
if !utils.IsStringInSlice(endpoint, validOIDCCORSEndpoints) {
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '")))
}
}
}
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
invalidID, duplicateIDs := false, false invalidID, duplicateIDs := false, false
@ -97,7 +156,6 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s
validateOIDCClientResponseTypes(c, config, validator) validateOIDCClientResponseTypes(c, config, validator)
validateOIDCClientResponseModes(c, config, validator) validateOIDCClientResponseModes(c, config, validator)
validateOIDDClientUserinfoAlgorithm(c, config, validator) validateOIDDClientUserinfoAlgorithm(c, config, validator)
validateOIDCClientRedirectURIs(client, validator) validateOIDCClientRedirectURIs(client, validator)
} }

View File

@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/utils"
) )
func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) { func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
@ -29,6 +31,54 @@ func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured) assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
} }
func TestShouldNotRaiseErrorWhenCORSEndpointsValid(t *testing.T) {
validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{
OIDC: &schema.OpenIDConnectConfiguration{
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
IssuerPrivateKey: "key-material",
CORS: schema.OpenIDConnectCORSConfiguration{
Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint},
},
Clients: []schema.OpenIDConnectClientConfiguration{
{
ID: "example",
Secret: "example",
},
},
},
}
ValidateIdentityProviders(config, validator)
assert.Len(t, validator.Errors(), 0)
}
func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) {
validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{
OIDC: &schema.OpenIDConnectConfiguration{
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
IssuerPrivateKey: "key-material",
CORS: schema.OpenIDConnectCORSConfiguration{
Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint, "invalid_endpoint"},
},
Clients: []schema.OpenIDConnectClientConfiguration{
{
ID: "example",
Secret: "example",
},
},
},
}
ValidateIdentityProviders(config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'")
}
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) { func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{ config := &schema.IdentityProvidersConfiguration{
@ -47,7 +97,44 @@ func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured) assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
} }
func TestShouldRaiseErrorWhenOIDCServerIssuerPrivateKeyPathInvalid(t *testing.T) { func TestShouldRaiseErrorWhenOIDCCORSOriginsHasInvalidValues(t *testing.T) {
validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{
OIDC: &schema.OpenIDConnectConfiguration{
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
IssuerPrivateKey: "key-material",
CORS: schema.OpenIDConnectCORSConfiguration{
AllowedOrigins: utils.URLsFromStringSlice([]string{"https://example.com/", "https://site.example.com/subpath", "https://site.example.com?example=true", "*"}),
AllowedOriginsFromClientRedirectURIs: true,
},
Clients: []schema.OpenIDConnectClientConfiguration{
{
ID: "myclient",
Secret: "jk12nb3klqwmnelqkwenm",
Policy: "two_factor",
RedirectURIs: []string{"https://example.com/oauth2_callback", "https://localhost:566/callback", "http://an.example.com/callback", "file://a/file"},
},
},
},
}
ValidateIdentityProviders(config, validator)
require.Len(t, validator.Errors(), 6)
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://example.com/' as it has a path: origins must only be scheme, hostname, and an optional port")
assert.EqualError(t, validator.Errors()[1], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com/subpath' as it has a path: origins must only be scheme, hostname, and an optional port")
assert.EqualError(t, validator.Errors()[2], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com?example=true' as it has a query string: origins must only be scheme, hostname, and an optional port")
assert.EqualError(t, validator.Errors()[3], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself")
assert.EqualError(t, validator.Errors()[4], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled")
assert.EqualError(t, validator.Errors()[5], "identity_providers: oidc: client 'myclient': option 'redirect_uris' has an invalid value: redirect uri 'file://a/file' must have a scheme of 'http' or 'https' but 'file' is configured")
require.Len(t, config.OIDC.CORS.AllowedOrigins, 6)
assert.Equal(t, "*", config.OIDC.CORS.AllowedOrigins[3].String())
assert.Equal(t, "https://example.com", config.OIDC.CORS.AllowedOrigins[4].String())
}
func TestShouldRaiseErrorWhenOIDCServerNoClients(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := &schema.IdentityProvidersConfiguration{ config := &schema.IdentityProvidersConfiguration{
OIDC: &schema.OpenIDConnectConfiguration{ OIDC: &schema.OpenIDConnectConfiguration{

View File

@ -72,16 +72,6 @@ const (
auth = "auth" auth = "auth"
) )
// OIDC constants.
const (
pathLegacyOpenIDConnectAuthorization = "/api/oidc/authorize"
pathLegacyOpenIDConnectIntrospection = "/api/oidc/introspect"
pathLegacyOpenIDConnectRevocation = "/api/oidc/revoke"
// Note: If you change this const you must also do so in the frontend at web/src/services/Api.ts.
pathOpenIDConnectConsent = "/api/oidc/consent"
)
const ( const (
accept = "accept" accept = "accept"
reject = "reject" reject = "reject"

View File

@ -6,7 +6,8 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
) )
func oidcJWKs(ctx *middlewares.AutheliaCtx) { // JSONWebKeySetGET returns the JSON Web Key Set. Used in OAuth 2.0 and OpenID Connect 1.0.
func JSONWebKeySetGET(ctx *middlewares.AutheliaCtx) {
ctx.SetContentType("application/json") ctx.SetContentType("application/json")
if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil { if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil {

View File

@ -9,7 +9,10 @@ import (
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
) )
func oidcIntrospection(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { // OAuthIntrospectionPOST handles POST requests to the OAuth 2.0 Introspection endpoint.
//
// https://datatracker.ietf.org/doc/html/rfc7662
func OAuthIntrospectionPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var ( var (
responder fosite.IntrospectionResponder responder fosite.IntrospectionResponder
err error err error

View File

@ -8,7 +8,10 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
) )
func oidcRevocation(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { // OAuthRevocationPOST handles POST requests to the OAuth 2.0 Revocation endpoint.
//
// https://datatracker.ietf.org/doc/html/rfc7009
func OAuthRevocationPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var err error var err error
if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil { if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil {

View File

@ -16,7 +16,10 @@ import (
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
) )
func oidcAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) { // OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint.
//
// https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
var ( var (
requester fosite.AuthorizeRequester requester fosite.AuthorizeRequester
responder fosite.AuthorizeResponder responder fosite.AuthorizeResponder

View File

@ -7,7 +7,8 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
) )
func oidcConsent(ctx *middlewares.AutheliaCtx) { // OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect.
func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession() userSession := ctx.GetSession()
if userSession.OIDCWorkflowSession == nil { if userSession.OIDCWorkflowSession == nil {
@ -39,7 +40,8 @@ func oidcConsent(ctx *middlewares.AutheliaCtx) {
} }
} }
func oidcConsentPOST(ctx *middlewares.AutheliaCtx) { // OpenIDConnectConsentPOST handles consent responses for OpenID Connect.
func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession() userSession := ctx.GetSession()
if userSession.OIDCWorkflowSession == nil { if userSession.OIDCWorkflowSession == nil {

View File

@ -9,7 +9,10 @@ import (
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
) )
func oidcToken(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { // OpenIDConnectTokenPOST handles POST requests to the OpenID Connect 1.0 Token endpoint.
//
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
func OpenIDConnectTokenPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var ( var (
requester fosite.AccessRequester requester fosite.AccessRequester
responder fosite.AccessResponder responder fosite.AccessResponder

View File

@ -14,7 +14,10 @@ import (
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
) )
func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) { // OpenIDConnectUserinfo handles GET/POST requests to the OpenID Connect 1.0 UserInfo endpoint.
//
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
var ( var (
tokenType fosite.TokenType tokenType fosite.TokenType
requester fosite.AccessRequester requester fosite.AccessRequester
@ -97,7 +100,7 @@ func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *htt
var jti uuid.UUID var jti uuid.UUID
if jti, err = uuid.NewRandom(); err != nil { if jti, err = uuid.NewRandom(); err != nil {
ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JWT ID.")) ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JTI."))
return return
} }

View File

@ -8,7 +8,13 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
) )
func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) { // OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
// OpenID Connect Discovery 1.0 metadata.
//
// https://datatracker.ietf.org/doc/html/rfc5785
//
// https://openid.net/specs/openid-connect-discovery-1_0.html
func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL() issuer, err := ctx.ExternalRootURL()
if err != nil { if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err) ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
@ -30,7 +36,13 @@ func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
} }
} }
func wellKnownOAuthAuthorizationServerGET(ctx *middlewares.AutheliaCtx) { // OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
// OAuth 2.0 Authorization Server Metadata (RFC8414).
//
// https://datatracker.ietf.org/doc/html/rfc5785
//
// https://datatracker.ietf.org/doc/html/rfc8414
func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL() issuer, err := ctx.ExternalRootURL()
if err != nil { if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err) ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)

View File

@ -1,37 +0,0 @@
package handlers
import (
"github.com/fasthttp/router"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/oidc"
)
// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout.
func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) {
// TODO: Add OPTIONS handler.
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET)))
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET)))
router.GET(pathOpenIDConnectConsent, middleware(oidcConsent))
router.POST(pathOpenIDConnectConsent, middleware(oidcConsentPOST))
router.GET(oidc.JWKsPath, middleware(oidcJWKs))
router.GET(oidc.AuthorizationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
router.GET(pathLegacyOpenIDConnectAuthorization, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
// TODO: Add OPTIONS handler.
router.POST(oidc.TokenPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcToken)))
router.POST(oidc.IntrospectionPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
router.GET(pathLegacyOpenIDConnectIntrospection, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
router.GET(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
router.POST(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
// TODO: Add OPTIONS handler.
router.POST(oidc.RevocationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
router.POST(pathLegacyOpenIDConnectRevocation, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
}

View File

@ -7,18 +7,22 @@ import (
) )
var ( var (
headerAccept = []byte(fasthttp.HeaderAccept)
headerContentLength = []byte(fasthttp.HeaderContentLength)
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto) headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost) headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor) headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor)
headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith) headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith)
headerAccept = []byte(fasthttp.HeaderAccept)
headerXForwardedURI = []byte("X-Forwarded-URI") headerXForwardedURI = []byte("X-Forwarded-URI")
headerXOriginalURL = []byte("X-Original-URL") headerXOriginalURL = []byte("X-Original-URL")
headerXForwardedMethod = []byte("X-Forwarded-Method") headerXForwardedMethod = []byte("X-Forwarded-Method")
headerVary = []byte(fasthttp.HeaderVary) headerVary = []byte(fasthttp.HeaderVary)
headerOrigin = []byte(fasthttp.HeaderOrigin) headerAllow = []byte(fasthttp.HeaderAllow)
headerOrigin = []byte(fasthttp.HeaderOrigin)
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials) headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders) headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods) headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
@ -29,9 +33,13 @@ var (
) )
var ( var (
headerValueFalse = []byte("false") headerValueFalse = []byte("false")
headerValueMaxAge = []byte("100") headerValueTrue = []byte("true")
headerValueVary = []byte("Accept-Encoding, Origin") headerValueMaxAge = []byte("100")
headerValueVary = []byte("Accept-Encoding, Origin")
headerValueVaryWildcard = []byte("Accept-Encoding")
headerValueOriginWildcard = []byte("*")
headerValueZero = []byte("0")
) )
var ( var (
@ -40,6 +48,8 @@ var (
// UserValueKeyBaseURL is the User Value key where we store the Base URL. // UserValueKeyBaseURL is the User Value key where we store the Base URL.
UserValueKeyBaseURL = []byte("base_url") UserValueKeyBaseURL = []byte("base_url")
headerSeparator = []byte(", ")
) )
const ( const (

View File

@ -1,53 +1,347 @@
package middlewares package middlewares
import ( import (
"bytes"
"net/url" "net/url"
"strconv"
"strings" "strings"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/utils"
) )
// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well // NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy which adds the Vary header
// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied // with a value reflecting that the Origin header will Vary this response, then if the Origin header has a https scheme
// to both Accept-Encoding and Origin. It grants the GET Request Method only. // it makes the following additional adjustments: copies the Origin header to the Access-Control-Allow-Origin header
func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler { // effectively allowing all origins, sets the Access-Control-Allow-Credentials header to false which disallows CORS
return func(ctx *AutheliaCtx) { // requests from sending cookies etc, sets the Access-Control-Allow-Headers header to the value specified by
if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil { // Access-Control-Request-Headers in the request excluding the Cookie/Authorization/Proxy-Authorization and special *
corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin) // values, sets Access-Control-Allow-Methods to the value specified by the Access-Control-Request-Method header, sets
// the Access-Control-Max-Age header to 100.
//
// These behaviours can be overridden by the With methods on the returned policy.
func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) {
return &CORSPolicyBuilder{
enabled: true,
maxAge: 100,
}
}
// CORSPolicyBuilder is a special middleware which provides CORS headers via handlers and middleware methods which can be
// configured. It aims to simplify CORS configurations.
type CORSPolicyBuilder struct {
enabled bool
varyOnly bool
varySet bool
methods []string
headers []string
origins []string
credentials bool
vary []string
maxAge int
}
// Build reads the CORSPolicyBuilder configuration and generates a CORSPolicy.
func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) {
policy = &CORSPolicy{
enabled: b.enabled,
varyOnly: b.varyOnly,
credentials: []byte(strconv.FormatBool(b.credentials)),
origins: b.buildOrigins(),
headers: b.buildHeaders(),
vary: b.buildVary(),
}
if len(b.methods) != 0 {
policy.methods = []byte(strings.Join(b.methods, ", "))
}
if b.maxAge <= 0 {
policy.maxAge = headerValueMaxAge
} else {
policy.maxAge = []byte(strconv.Itoa(b.maxAge))
}
return policy
}
func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) {
if len(b.origins) != 0 {
if len(b.origins) == 1 && b.origins[0] == "*" {
origins = append(origins, []byte(b.origins[0]))
} else {
for _, origin := range b.origins {
origins = append(origins, []byte(origin))
}
} }
}
return origins
}
func (b CORSPolicyBuilder) buildHeaders() (headers []byte) {
if len(b.headers) != 0 {
h := b.headers
if b.credentials {
if !utils.IsStringInSliceFold(fasthttp.HeaderCookie, h) {
h = append(h, fasthttp.HeaderCookie)
}
if !utils.IsStringInSliceFold(fasthttp.HeaderAuthorization, h) {
h = append(h, fasthttp.HeaderAuthorization)
}
if !utils.IsStringInSliceFold(fasthttp.HeaderProxyAuthorization, h) {
h = append(h, fasthttp.HeaderProxyAuthorization)
}
}
headers = utils.JoinAndCanonicalizeHeaders(headerSeparator, h...)
}
return headers
}
func (b CORSPolicyBuilder) buildVary() (vary []byte) {
if b.varySet {
if len(b.vary) != 0 {
vary = utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...)
}
} else {
if len(b.origins) == 1 && b.origins[0] == "*" {
vary = headerValueVaryWildcard
} else {
vary = headerValueVary
}
}
return vary
}
// WithEnabled changes the enabled state of the middleware. If the middleware is initialized with NewCORSPolicyBuilder this
// value will be true but this function can override the value. Setting it to false prevents the middleware from adding
// any CORS headers. The only effect this middleware has after disabling this is the HandleOPTIONS and HandleOnlyOPTIONS
// handlers still function to return a HTTP 204 No Content, with the Allow header communicating the available HTTP
// method verbs. The main benefit of this option is that you don't have to implement complex logic to add/remove the
// middleware, you can just add it with the Middleware method, and adjust it using the WithEnabled method.
func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) {
b.enabled = enabled
return b
}
// WithAllowedMethods takes a list or HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with
// that value.
func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) {
b.methods = methods
return b
}
// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the origin matches one of these.
func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) {
b.origins = origins
return b
}
// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header.
func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) {
b.headers = headers
return b
}
// WithAllowCredentials takes bool and alters the default Access-Control-Allow-Credentials header.
func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) {
b.credentials = allow
return b
}
// WithVary takes a list of header strings and alters the default Vary header.
func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) {
b.vary = headers
b.varySet = true
return b
}
// WithVaryOnly just adds the Vary header.
func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) {
b.varyOnly = varyOnly
return b
}
// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header.
func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) {
b.maxAge = age
return b
}
// CORSPolicy is a middleware that handles adding CORS headers.
type CORSPolicy struct {
enabled bool
varyOnly bool
methods []byte
headers []byte
origins [][]byte
credentials []byte
vary []byte
maxAge []byte
}
// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 204
// without a body. This handler should generally not be used without using WithAllowedMethods.
func (p CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) {
p.handleOPTIONS(ctx)
p.handle(ctx)
}
// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 204
// without a body. This handler should generally not be used without using WithAllowedMethods.
func (p CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) {
p.handleOPTIONS(ctx)
}
// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicyBuilder.
func (p CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) {
return func(ctx *fasthttp.RequestCtx) {
p.handle(ctx)
next(ctx) next(ctx)
} }
} }
func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) { func (p CORSPolicy) handle(ctx *fasthttp.RequestCtx) {
originURL, err := url.Parse(string(origin)) if !p.enabled {
if err != nil || originURL.Scheme != "https" {
return return
} }
resp.Header.SetBytesKV(headerVary, headerValueVary) p.handleVary(ctx)
resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin)
resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse)
resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge)
if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil { if !p.varyOnly {
requestedHeaders := strings.Split(string(headers), ",") p.handleCORS(ctx)
allowHeaders := make([]string, len(requestedHeaders)) }
}
for i, header := range requestedHeaders { func (p CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) {
headerTrimmed := strings.Trim(header, " ") ctx.Response.ResetBody()
if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) {
allowHeaders[i] = headerTrimmed /* The OPTIONS method should not return a 204 as per the following specifications when read together:
RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7):
A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in
the response.
RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2):
A server MUST NOT send a Content-Length header field in any response with a status code of 1xx (Informational)
or 204 (No Content).
*/
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.Response.Header.SetBytesKV(headerContentLength, headerValueZero)
if len(p.methods) != 0 {
ctx.Response.Header.SetBytesKV(headerAllow, p.methods)
}
}
func (p CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) {
if len(p.vary) != 0 {
ctx.Response.Header.SetBytesKV(headerVary, p.vary)
}
}
func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
var (
originURL *url.URL
err error
)
origin := ctx.Request.Header.PeekBytes(headerOrigin)
// Skip processing of any `https` scheme URL that has not expressly been configured.
if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) {
return
}
var allowedOrigin []byte
switch len(p.origins) {
case 0:
allowedOrigin = origin
default:
for i := 0; i < len(p.origins); i++ {
if bytes.Equal(p.origins[i], headerValueOriginWildcard) {
allowedOrigin = headerValueOriginWildcard
} else if bytes.Equal(p.origins[i], origin) {
allowedOrigin = origin
} }
} }
if len(allowHeaders) != 0 { if len(allowedOrigin) == 0 {
resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", "))) return
} }
} }
if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil { ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin)
resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
if len(p.credentials) != 0 {
ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials)
}
if len(p.maxAge) != 0 {
ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge)
}
p.handleAllowedHeaders(ctx)
p.handleAllowedMethods(ctx)
}
func (p CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) {
switch len(p.methods) {
case 0:
// TODO: It may be beneficial to be able to control this automatic behaviour.
if requestMethods := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
}
default:
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods)
}
}
func (p CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) {
switch len(p.headers) {
case 0:
// TODO: It may be beneficial to be able to control this automatic behaviour.
if headers := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
requestedHeaders := strings.Split(string(headers), ",")
allowHeaders := make([]string, 0, len(requestedHeaders))
for i := 0; i < len(requestedHeaders); i++ {
headerTrimmed := strings.Trim(requestedHeaders[i], " ")
if headerTrimmed == "*" {
continue
}
if bytes.Equal(p.credentials, headerValueTrue) ||
(!strings.EqualFold(fasthttp.HeaderCookie, headerTrimmed) &&
!strings.EqualFold(fasthttp.HeaderAuthorization, headerTrimmed) &&
!strings.EqualFold(fasthttp.HeaderProxyAuthorization, headerTrimmed)) {
allowHeaders = append(allowHeaders, headerTrimmed)
}
}
if len(allowHeaders) != 0 {
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
}
}
default:
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers)
} }
} }

View File

@ -5,61 +5,587 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema"
) )
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) { func TestNewCORSMiddleware(t *testing.T) {
req := fasthttp.AcquireRequest() cors := NewCORSPolicyBuilder()
resp := fasthttp.Response{}
assert.Equal(t, 100, cors.maxAge)
assert.Equal(t, false, cors.credentials)
assert.Nil(t, cors.methods)
assert.Nil(t, cors.origins)
assert.Nil(t, cors.headers)
assert.Nil(t, cors.vary)
assert.False(t, cors.varyOnly)
assert.False(t, cors.varySet)
}
func TestCORSPolicyBuilder_WithEnabled(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.True(t, cors.enabled)
cors.WithEnabled(false)
assert.False(t, cors.enabled)
}
func TestCORSPolicyBuilder_WithVary(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.Nil(t, cors.vary)
assert.False(t, cors.varyOnly)
assert.False(t, cors.varySet)
cors.WithVary()
assert.Nil(t, cors.vary)
assert.False(t, cors.varyOnly)
assert.True(t, cors.varySet)
cors.WithVary("Origin", "Example", "Test")
assert.Equal(t, []string{"Origin", "Example", "Test"}, cors.vary)
assert.False(t, cors.varyOnly)
assert.True(t, cors.varySet)
}
func TestCORSPolicyBuilder_WithAllowedMethods(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.Nil(t, cors.methods)
cors.WithAllowedMethods("GET")
assert.Equal(t, []string{"GET"}, cors.methods)
cors.WithAllowedMethods("POST", "PATCH")
assert.Equal(t, []string{"POST", "PATCH"}, cors.methods)
cors.WithAllowedMethods()
assert.Nil(t, cors.methods)
}
func TestCORSPolicyBuilder_WithAllowedOrigins(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.Nil(t, cors.origins)
cors.WithAllowedOrigins("https://google.com", "http://localhost")
assert.Equal(t, []string{"https://google.com", "http://localhost"}, cors.origins)
cors.WithAllowedOrigins()
assert.Nil(t, cors.origins)
}
func TestCORSPolicyBuilder_WithAllowedHeaders(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.Nil(t, cors.headers)
cors.WithAllowedHeaders("Example", "Another")
assert.Equal(t, []string{"Example", "Another"}, cors.headers)
cors.WithAllowedHeaders()
assert.Nil(t, cors.headers)
}
func TestCORSPolicyBuilder_WithAllowCredentials(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.Equal(t, false, cors.credentials)
cors.WithAllowCredentials(false)
assert.Equal(t, false, cors.credentials)
cors.WithAllowCredentials(true)
assert.Equal(t, true, cors.credentials)
}
func TestCORSPolicyBuilder_WithVaryOnly(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.False(t, cors.varyOnly)
cors.WithVaryOnly(false)
assert.False(t, cors.varyOnly)
cors.WithVaryOnly(true)
cors.WithVaryOnly(true)
}
func TestCORSPolicyBuilder_WithMaxAge(t *testing.T) {
cors := NewCORSPolicyBuilder()
assert.Equal(t, 100, cors.maxAge)
cors.WithMaxAge(20)
assert.Equal(t, 20, cors.maxAge)
cors.WithMaxAge(0)
assert.Equal(t, 0, cors.maxAge)
}
func TestCORSPolicyBuilder_HandleOPTIONS(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com") origin := []byte("https://myapp.example.com")
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
corsApplyAutomaticAllowAllPolicy(req, &resp, origin) cors := NewCORSPolicyBuilder()
policy := cors.Build()
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary)) policy.HandleOPTIONS(ctx)
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials)) assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge)) assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods)) assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithAllowedMethods("GET", "OPTIONS")
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
policy = cors.Build()
policy.HandleOnlyOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithEnabled(false)
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func TestCORSPolicyBuilder_HandleOPTIONS_WithoutOrigin(t *testing.T) {
ctx := newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
cors := NewCORSPolicyBuilder()
policy := cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
cors.WithAllowedMethods("GET", "OPTIONS")
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedOrigins(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors := NewCORSPolicyBuilder()
cors.WithAllowedOrigins("https://myapp.example.com")
policy := cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithAllowedOrigins("https://anotherapp.example.com")
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithAllowedOrigins("*")
cors.WithAllowedMethods("GET", "OPTIONS")
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func TestCORSPolicyBuilder_WithAllowedOrigins_DoesntOverrideVary(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors := NewCORSPolicyBuilder()
cors.WithVary("Accept-Encoding", "Origin", "Test")
cors.WithAllowedOrigins("*")
policy := cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin, Test"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func TestCORSPolicyBuilder_HandleOPTIONSWithVaryOnly(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors := NewCORSPolicyBuilder()
cors.WithVaryOnly(true)
policy := cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithAllowedMethods("GET", "OPTIONS")
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedHeaders(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors := NewCORSPolicyBuilder()
cors.WithAllowedHeaders("Example", "Test")
policy := cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithAllowedMethods("GET", "OPTIONS")
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
ctx = newFastHTTPRequestCtx()
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors.WithAllowCredentials(true)
policy = cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueTrue, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("Example, Test, Cookie, Authorization, Proxy-Authorization"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func TestCORSPolicyBuilder_HandleOPTIONS_ShouldNotAllowWildcardInRequestedHeaders(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "*")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
cors := NewCORSPolicyBuilder()
policy := cors.Build()
policy.HandleOPTIONS(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
cors := NewCORSPolicyBuilder()
policy := cors.Build()
policy.handle(ctx)
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
} }
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) { func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
req := fasthttp.AcquireRequest() ctx := newFastHTTPRequestCtx()
resp := fasthttp.Response{}
origin := []byte("https://myapp.example.com") origin := []byte("https://myapp.example.com")
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") ctx.Request.Header.SetBytesKV(headerOrigin, origin)
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET") ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
corsApplyAutomaticAllowAllPolicy(req, &resp, origin) cors := NewCORSPolicyBuilder()
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary)) policy := cors.Build()
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin)) policy.handle(ctx)
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge)) assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods)) assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
} }
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) { func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
req := fasthttp.AcquireRequest() ctx := newFastHTTPRequestCtx()
resp := fasthttp.Response{}
origin := []byte("http://myapp.example.com") origin := []byte("http://myapp.example.com")
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header") ctx.Request.Header.SetBytesKV(headerOrigin, origin)
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET") ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
corsApplyAutomaticAllowAllPolicy(req, &resp, origin) cors := NewCORSPolicyBuilder().WithVary()
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary)) policy := cors.Build()
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin)) policy.handle(ctx)
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge)) assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders)) assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods)) assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func Test_CORSMiddleware_AsMiddleware(t *testing.T) {
ctx := newFastHTTPRequestCtx()
origin := []byte("https://myapp.example.com")
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
autheliaMiddleware := AutheliaMiddleware(schema.Configuration{}, Providers{})
cors := NewCORSPolicyBuilder().WithAllowedMethods("GET", "OPTIONS")
policy := cors.Build()
route := policy.Middleware(autheliaMiddleware(testNilHandler))
route(ctx)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
}
func testNilHandler(_ *AutheliaCtx) {}
func newFastHTTPRequestCtx() (ctx *fasthttp.RequestCtx) {
return &fasthttp.RequestCtx{
Request: fasthttp.Request{},
Response: fasthttp.Response{},
}
} }

View File

@ -19,17 +19,28 @@ const (
ClaimEmailAlts = "alt_emails" ClaimEmailAlts = "alt_emails"
) )
// Endpoints.
const (
AuthorizationEndpoint = "authorization"
TokenEndpoint = "token"
UserinfoEndpoint = "userinfo"
IntrospectionEndpoint = "introspection"
RevocationEndpoint = "revocation"
)
// Paths. // Paths.
const ( const (
WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration" WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration"
WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server" WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server"
JWKsPath = "/jwks.json"
JWKsPath = "/api/oidc/jwks" RootPath = "/api/oidc"
AuthorizationPath = "/api/oidc/authorization"
TokenPath = "/api/oidc/token" //nolint:gosec // This is not a hard coded credential, it's a path. AuthorizationPath = RootPath + "/" + AuthorizationEndpoint
IntrospectionPath = "/api/oidc/introspection" TokenPath = RootPath + "/" + TokenEndpoint
RevocationPath = "/api/oidc/revocation" UserinfoPath = RootPath + "/" + UserinfoEndpoint
UserinfoPath = "/api/oidc/userinfo" IntrospectionPath = RootPath + "/" + IntrospectionEndpoint
RevocationPath = RootPath + "/" + RevocationEndpoint
) )
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176 // Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176

View File

@ -87,7 +87,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com") disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
assert.Equal(t, "https://example.com", disco.Issuer) assert.Equal(t, "https://example.com", disco.Issuer)
assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI) assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint) assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint) assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint) assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint)
@ -173,7 +173,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig
disco := provider.GetOAuth2WellKnownConfiguration("https://example.com") disco := provider.GetOAuth2WellKnownConfiguration("https://example.com")
assert.Equal(t, "https://example.com", disco.Issuer) assert.Equal(t, "https://example.com", disco.Issuer)
assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI) assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint) assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint) assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint) assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint)

View File

@ -37,16 +37,17 @@ var (
{name: "/api", prefix: "/api/"}, {name: "/api", prefix: "/api/"},
{name: "/.well-known", prefix: "/.well-known/"}, {name: "/.well-known", prefix: "/.well-known/"},
{name: "/static", prefix: "/static/"}, {name: "/static", prefix: "/static/"},
{name: "/locales", prefix: "/locales/"},
} }
) )
const schemeHTTP = "http"
const schemeHTTPS = "https"
const ( const (
dev = "dev" dev = "dev"
f = "false" f = "false"
t = "true" t = "true"
localhost = "localhost"
schemeHTTP = "http"
schemeHTTPS = "https"
) )
const healthCheckEnv = `# Written by Authelia Process const healthCheckEnv = `# Written by Authelia Process

View File

@ -1,28 +0,0 @@
package server
import (
"net"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/logging"
)
// Replacement for the default error handler in fasthttp.
func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) {
logger := logging.Logger()
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
// Note: Getting X-Forwarded-For or Request URI is impossible for ths error.
logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
// TODO: Add X-Forwarded-For Check here.
logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
} else {
// TODO: Add X-Forwarded-For Check here.
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
}
}

View File

@ -1,25 +0,0 @@
package server
import (
"strings"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/handlers"
)
func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
path := strings.ToLower(string(ctx.Path()))
for i := 0; i < len(httpServerDirs); i++ {
if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
return
}
}
next(ctx)
}
}

View File

@ -0,0 +1,56 @@
package server
import (
"net"
"strings"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/handlers"
"github.com/authelia/authelia/v4/internal/logging"
)
// Replacement for the default error handler in fasthttp.
func handlerErrors(ctx *fasthttp.RequestCtx, err error) {
logger := logging.Logger()
switch e := err.(type) {
case *fasthttp.ErrSmallBuffer:
logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
case *net.OpError:
if e.Timeout() {
// TODO: Add X-Forwarded-For Check here.
logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
} else {
// TODO: Add X-Forwarded-For Check here.
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
}
default:
// TODO: Add X-Forwarded-For Check here.
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
}
}
func handlerNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
path := strings.ToLower(string(ctx.Path()))
for i := 0; i < len(httpServerDirs); i++ {
if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
return
}
}
next(ctx)
}
}
func handlerMethodNotAllowed(ctx *fasthttp.RequestCtx) {
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
}

View File

@ -1,11 +0,0 @@
package server
import (
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/middlewares"
)
func handleOPTIONS(ctx *middlewares.AutheliaCtx) {
ctx.SetStatusCode(fasthttp.StatusNoContent)
}

View File

@ -19,10 +19,12 @@ import (
"github.com/authelia/authelia/v4/internal/handlers" "github.com/authelia/authelia/v4/internal/handlers"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/utils"
) )
// TODO: move to its own file and rename configuration -> config.
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled) rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled)
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword) resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
@ -33,37 +35,49 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment) duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment)
} }
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
handlerLocales := newLocalesEmbeddedHandler()
https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != "" https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != ""
serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https) serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https) serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https) serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
handlerLocales := newLocalesEmbeddedHandler()
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
WithAllowedMethods("OPTIONS", "GET").
WithAllowedOrigins("*").
Build()
r := router.New() r := router.New()
// Static Assets.
r.GET("/", autheliaMiddleware(serveIndexHandler)) r.GET("/", autheliaMiddleware(serveIndexHandler))
r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
for _, f := range rootFiles { for _, f := range rootFiles {
r.GET("/"+f, handlerPublicHTML) r.GET("/"+f, handlerPublicHTML)
} }
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler))
for _, file := range swaggerFiles {
r.GET("/api/"+file, handlerPublicHTML)
}
r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML)) r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML))
r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML)) r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML))
r.GET("/static/{filepath:*}", handlerPublicHTML) r.GET("/static/{filepath:*}", handlerPublicHTML)
// Locales.
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales)) r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales)) r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
// Swagger.
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)
r.GET("/api/"+apiFile, policyCORSPublicGET.Middleware(autheliaMiddleware(serveSwaggerAPIHandler)))
r.OPTIONS("/api/"+apiFile, policyCORSPublicGET.HandleOPTIONS)
for _, file := range swaggerFiles {
r.GET("/api/"+file, handlerPublicHTML)
}
r.GET("/api/health", autheliaMiddleware(handlers.HealthGet)) r.GET("/api/health", autheliaMiddleware(handlers.HealthGet))
r.GET("/api/state", autheliaMiddleware(handlers.StateGet)) r.GET("/api/state", autheliaMiddleware(handlers.StateGet))
@ -161,22 +175,98 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
r.GET("/debug/vars", expvarhandler.ExpvarHandler) r.GET("/debug/vars", expvarhandler.ExpvarHandler)
} }
r.NotFound = handleNotFound(autheliaMiddleware(serveIndexHandler)) if providers.OpenIDConnect.Fosite != nil {
r.GET("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentGET))
r.POST("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentPOST))
allowedOrigins := utils.StringSliceFromURLs(configuration.IdentityProviders.OIDC.CORS.AllowedOrigins)
r.OPTIONS(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.HandleOPTIONS)
r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OpenIDConnectConfigurationWellKnownGET)))
r.OPTIONS(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.HandleOPTIONS)
r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OAuthAuthorizationServerWellKnownGET)))
r.OPTIONS(oidc.JWKsPath, policyCORSPublicGET.HandleOPTIONS)
r.GET(oidc.JWKsPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS)
r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
policyCORSAuthorization := middlewares.NewCORSPolicyBuilder().
WithAllowedMethods("OPTIONS", "GET").
WithAllowedOrigins(allowedOrigins...).
WithEnabled(utils.IsStringInSlice(oidc.AuthorizationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
Build()
r.OPTIONS(oidc.AuthorizationPath, policyCORSAuthorization.HandleOnlyOPTIONS)
r.GET(oidc.AuthorizationPath, autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
// TODO (james-d-elliott): Remove in GA. This is a legacy endpoint.
r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS)
r.GET("/api/oidc/authorize", autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
policyCORSToken := middlewares.NewCORSPolicyBuilder().
WithAllowCredentials(true).
WithAllowedMethods("OPTIONS", "POST").
WithAllowedOrigins(allowedOrigins...).
WithEnabled(utils.IsStringInSlice(oidc.TokenEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
Build()
r.OPTIONS(oidc.TokenPath, policyCORSToken.HandleOPTIONS)
r.POST(oidc.TokenPath, policyCORSToken.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST))))
policyCORSUserinfo := middlewares.NewCORSPolicyBuilder().
WithAllowCredentials(true).
WithAllowedMethods("OPTIONS", "GET", "POST").
WithAllowedOrigins(allowedOrigins...).
WithEnabled(utils.IsStringInSlice(oidc.UserinfoEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
Build()
r.OPTIONS(oidc.UserinfoPath, policyCORSUserinfo.HandleOPTIONS)
r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
policyCORSIntrospection := middlewares.NewCORSPolicyBuilder().
WithAllowCredentials(true).
WithAllowedMethods("OPTIONS", "POST").
WithAllowedOrigins(allowedOrigins...).
WithEnabled(utils.IsStringInSlice(oidc.IntrospectionEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
Build()
r.OPTIONS(oidc.IntrospectionPath, policyCORSIntrospection.HandleOPTIONS)
r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS)
r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
policyCORSRevocation := middlewares.NewCORSPolicyBuilder().
WithAllowCredentials(true).
WithAllowedMethods("OPTIONS", "POST").
WithAllowedOrigins(allowedOrigins...).
WithEnabled(utils.IsStringInSlice(oidc.RevocationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
Build()
r.OPTIONS(oidc.RevocationPath, policyCORSRevocation.HandleOPTIONS)
r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS)
r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
}
r.NotFound = handlerNotFound(autheliaMiddleware(serveIndexHandler))
r.HandleMethodNotAllowed = true r.HandleMethodNotAllowed = true
r.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) { r.MethodNotAllowed = handlerMethodNotAllowed
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
}
handler := middlewares.LogRequestMiddleware(r.Handler) handler := middlewares.LogRequestMiddleware(r.Handler)
if configuration.Server.Path != "" { if configuration.Server.Path != "" {
handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler) handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler)
} }
if providers.OpenIDConnect.Fosite != nil {
handlers.RegisterOIDC(r, autheliaMiddleware)
}
return handler return handler
} }
@ -185,12 +275,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
handler := registerRoutes(configuration, providers) handler := registerRoutes(configuration, providers)
server := &fasthttp.Server{ server := &fasthttp.Server{
ErrorHandler: autheliaErrorHandler, ErrorHandler: handlerErrors,
Handler: handler, Handler: handler,
NoDefaultServerHeader: true, NoDefaultServerHeader: true,
ReadBufferSize: configuration.Server.ReadBufferSize, ReadBufferSize: configuration.Server.ReadBufferSize,
WriteBufferSize: configuration.Server.WriteBufferSize, WriteBufferSize: configuration.Server.WriteBufferSize,
} }
logger := logging.Logger() logger := logging.Logger()
address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port)) address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port))
@ -204,9 +295,8 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" { if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" {
connectionType, connectionScheme = "TLS", schemeHTTPS connectionType, connectionScheme = "TLS", schemeHTTPS
err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)
if err != nil { if err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key); err != nil {
logger.Fatalf("unable to load certificate: %v", err) logger.Fatalf("unable to load certificate: %v", err)
} }
@ -228,14 +318,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
} }
listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()) if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
if err != nil {
logger.Fatalf("Error initializing listener: %s", err) logger.Fatalf("Error initializing listener: %s", err)
} }
} else { } else {
connectionType, connectionScheme = "non-TLS", schemeHTTP connectionType, connectionScheme = "non-TLS", schemeHTTP
listener, err = net.Listen("tcp", address)
if err != nil { if listener, err = net.Listen("tcp", address); err != nil {
logger.Fatalf("Error initializing listener: %s", err) logger.Fatalf("Error initializing listener: %s", err)
} }
} }
@ -245,11 +334,10 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
logger.Fatalf("Could not configure healthcheck: %v", err) logger.Fatalf("Could not configure healthcheck: %v", err)
} }
actualAddress := listener.Addr().String()
if configuration.Server.Path == "" { if configuration.Server.Path == "" {
logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress) logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, listener.Addr().String())
} else { } else {
logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path) logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, listener.Addr().String(), configuration.Server.Path)
} }
return server, listener return server, listener

View File

@ -48,14 +48,14 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
} }
} }
var scheme = "https" var scheme = schemeHTTPS
if !https { if !https {
proto := string(ctx.XForwardedProto()) proto := string(ctx.XForwardedProto())
switch proto { switch proto {
case "": case "":
break break
case "http", "https": case schemeHTTP, schemeHTTPS:
scheme = proto scheme = proto
} }
} }
@ -116,7 +116,7 @@ func writeHealthCheckEnv(disabled bool, scheme, host, path string, port int) (er
}() }()
if host == "0.0.0.0" { if host == "0.0.0.0" {
host = "localhost" host = localhost
} else if strings.Contains(host, ":") { } else if strings.Contains(host, ":") {
host = "[" + host + "]" host = "[" + host + "]"
} }

View File

@ -8,6 +8,8 @@ import (
"strings" "strings"
"time" "time"
"unicode" "unicode"
"github.com/valyala/fasthttp"
) )
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error // IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
@ -145,6 +147,52 @@ func IsStringSlicesDifferentFold(a, b []string) (different bool) {
return isStringSlicesDifferent(a, b, IsStringInSliceFold) return isStringSlicesDifferent(a, b, IsStringInSliceFold)
} }
// IsURLInSlice returns true if the needle url.URL is in the []url.URL haystack.
func IsURLInSlice(needle url.URL, haystack []url.URL) (has bool) {
for i := 0; i < len(haystack); i++ {
if strings.EqualFold(needle.String(), haystack[i].String()) {
return true
}
}
return false
}
// StringSliceFromURLs returns a []string from a []url.URL.
func StringSliceFromURLs(urls []url.URL) []string {
result := make([]string, len(urls))
for i := 0; i < len(urls); i++ {
result[i] = urls[i].String()
}
return result
}
// URLsFromStringSlice returns a []url.URL from a []string.
func URLsFromStringSlice(urls []string) []url.URL {
var result []url.URL
for i := 0; i < len(urls); i++ {
u, err := url.Parse(urls[i])
if err != nil {
continue
}
result = append(result, *u)
}
return result
}
// OriginFromURL returns an origin url.URL given another url.URL.
func OriginFromURL(u url.URL) (origin url.URL) {
return url.URL{
Scheme: u.Scheme,
Host: u.Host,
}
}
// StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string. // StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string.
func StringSlicesDelta(before, after []string) (added, removed []string) { func StringSlicesDelta(before, after []string) (added, removed []string) {
for _, s := range before { for _, s := range before {
@ -193,6 +241,19 @@ func StringHTMLEscape(input string) (output string) {
return htmlEscaper.Replace(input) return htmlEscaper.Replace(input)
} }
// JoinAndCanonicalizeHeaders join header strings by a given sep.
func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) {
for i, header := range headers {
if i != 0 {
joined = append(joined, sep...)
}
joined = fasthttp.AppendNormalizedHeaderKey(joined, header)
}
return joined
}
func init() { func init() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"net/url"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -171,3 +172,48 @@ func TestIsStringSliceContainsAny(t *testing.T) {
assert.False(t, IsStringSliceContainsAny(needles, haystackOne)) assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo)) assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
} }
func TestStringSliceURLConversionFuncs(t *testing.T) {
urls := URLsFromStringSlice([]string{"https://google.com", "abc", "%*()@#$J(@*#$J@#($H"})
require.Len(t, urls, 2)
assert.Equal(t, "https://google.com", urls[0].String())
assert.Equal(t, "abc", urls[1].String())
strs := StringSliceFromURLs(urls)
require.Len(t, strs, 2)
assert.Equal(t, "https://google.com", strs[0])
assert.Equal(t, "abc", strs[1])
}
func TestIsURLInSlice(t *testing.T) {
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
google, err := url.Parse("https://google.com")
assert.NoError(t, err)
microsoft, err := url.Parse("https://microsoft.com")
assert.NoError(t, err)
example, err := url.Parse("https://example.com")
assert.NoError(t, err)
assert.True(t, IsURLInSlice(*google, urls))
assert.False(t, IsURLInSlice(*microsoft, urls))
assert.True(t, IsURLInSlice(*example, urls))
}
func TestOriginFromURL(t *testing.T) {
google, err := url.Parse("https://google.com/abc?a=123#five")
assert.NoError(t, err)
origin := OriginFromURL(*google)
assert.Equal(t, "https://google.com", origin.String())
}
func TestJoinAndCanonicalizeHeaders(t *testing.T) {
result := JoinAndCanonicalizeHeaders([]byte(", "), "x-example-ONE", "X-EGG-Two")
assert.Equal(t, []byte("X-Example-One, X-Egg-Two"), result)
}