feat(oidc): provide cors config including options handlers (#3005)
This adjusts the CORS headers appropriately for OpenID Connect. This includes responding to OPTIONS requests appropriately. Currently this is only configured to operate when the Origin scheme is HTTPS; but can easily be expanded in the future to include additional Origins.pull/3125/head
parent
a694cf851f
commit
4ebd8fdf4e
|
@ -767,6 +767,26 @@ notifier:
|
||||||
## for security reasons.
|
## for security reasons.
|
||||||
# enforce_pkce: public_clients_only
|
# enforce_pkce: public_clients_only
|
||||||
|
|
||||||
|
## Cross-Origin Resource Sharing (CORS) settings.
|
||||||
|
# cors:
|
||||||
|
## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
|
||||||
|
# endpoints:
|
||||||
|
# - authorization
|
||||||
|
# - token
|
||||||
|
# - revocation
|
||||||
|
# - introspection
|
||||||
|
# - userinfo
|
||||||
|
|
||||||
|
## List of allowed origins.
|
||||||
|
## Any origin with https is permitted unless this option is configured or the
|
||||||
|
## allowed_origins_from_client_redirect_uris option is enabled.
|
||||||
|
# allowed_origins:
|
||||||
|
# - https://example.com
|
||||||
|
|
||||||
|
## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
|
||||||
|
## provided they have the scheme http or https and do not have the hostname of localhost.
|
||||||
|
# allowed_origins_from_client_redirect_uris: false
|
||||||
|
|
||||||
## Clients is a list of known clients and their configuration.
|
## Clients is a list of known clients and their configuration.
|
||||||
# clients:
|
# clients:
|
||||||
# -
|
# -
|
||||||
|
|
|
@ -35,6 +35,15 @@ identity_providers:
|
||||||
refresh_token_lifespan: 90m
|
refresh_token_lifespan: 90m
|
||||||
enable_client_debug_messages: false
|
enable_client_debug_messages: false
|
||||||
enforce_pkce: public_clients_only
|
enforce_pkce: public_clients_only
|
||||||
|
cors:
|
||||||
|
endpoints:
|
||||||
|
- authorization
|
||||||
|
- token
|
||||||
|
- revocation
|
||||||
|
- introspection
|
||||||
|
allowed_origins:
|
||||||
|
- https://example.com
|
||||||
|
allowed_origins_from_client_redirect_uris: false
|
||||||
clients:
|
clients:
|
||||||
- id: myapp
|
- id: myapp
|
||||||
description: My Application
|
description: My Application
|
||||||
|
@ -218,6 +227,79 @@ Allows PKCE `plain` challenges when set to `true`.
|
||||||
|
|
||||||
***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead.
|
***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead.
|
||||||
|
|
||||||
|
### cors
|
||||||
|
|
||||||
|
Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
|
||||||
|
you to configure the optional parts. We reply with CORS headers when the request includes the Origin header.
|
||||||
|
|
||||||
|
##### endpoints
|
||||||
|
<div markdown="1">
|
||||||
|
type: list(string)
|
||||||
|
{: .label .label-config .label-purple }
|
||||||
|
default: empty
|
||||||
|
{: .label .label-config .label-blue }
|
||||||
|
required: no
|
||||||
|
{: .label .label-config .label-green }
|
||||||
|
</div>
|
||||||
|
|
||||||
|
A list of endpoints to configure with cross-origin resource sharing headers. It is recommended that the `userinfo`
|
||||||
|
option is at least in this list. The potential endpoints which this can be enabled on are as follows:
|
||||||
|
|
||||||
|
* authorization
|
||||||
|
* token
|
||||||
|
* revocation
|
||||||
|
* introspection
|
||||||
|
* userinfo
|
||||||
|
|
||||||
|
#### allowed_origins
|
||||||
|
<div markdown="1">
|
||||||
|
type: list(string)
|
||||||
|
{: .label .label-config .label-purple }
|
||||||
|
default: empty
|
||||||
|
{: .label .label-config .label-blue }
|
||||||
|
required: no
|
||||||
|
{: .label .label-config .label-green }
|
||||||
|
</div>
|
||||||
|
|
||||||
|
A list of permitted origins.
|
||||||
|
|
||||||
|
Any origin with https is permitted unless this option is configured or the allowed_origins_from_client_redirect_uris
|
||||||
|
option is enabled. This means you must configure this option manually if you want http endpoints to be permitted to
|
||||||
|
make cross-origin requests to the OpenID Connect endpoints, however this is not recommended.
|
||||||
|
|
||||||
|
Origins must only have the scheme, hostname and port, they may not have a trailing slash or path.
|
||||||
|
|
||||||
|
In addition to an Origin URI, you may specify the wildcard origin in the allowed_origins. It MUST be specified by itself
|
||||||
|
and the allowed_origins_from_client_redirect_uris MUST NOT be enabled. The wildcard origin is denoted as `*`. Examples:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
identity_providers:
|
||||||
|
oidc:
|
||||||
|
cors:
|
||||||
|
allowed_origins: "*"
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
identity_providers:
|
||||||
|
oidc:
|
||||||
|
cors:
|
||||||
|
allowed_origins:
|
||||||
|
- "*"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### allowed_origins_from_client_redirect_uris
|
||||||
|
<div markdown="1">
|
||||||
|
type: boolean
|
||||||
|
{: .label .label-config .label-purple }
|
||||||
|
default: false
|
||||||
|
{: .label .label-config .label-blue }
|
||||||
|
required: no
|
||||||
|
{: .label .label-config .label-green }
|
||||||
|
</div>
|
||||||
|
|
||||||
|
Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, provided they
|
||||||
|
have the scheme http or https and do not have the hostname of localhost.
|
||||||
|
|
||||||
### clients
|
### clients
|
||||||
|
|
||||||
A list of clients to configure. The options for each client are described below.
|
A list of clients to configure. The options for each client are described below.
|
||||||
|
@ -487,22 +569,46 @@ Below is a list of the potential values we place in the claim and their meaning:
|
||||||
|
|
||||||
## Endpoint Implementations
|
## Endpoint Implementations
|
||||||
|
|
||||||
This is a table of the endpoints we currently support and their paths. This can be requrired information for some RP's,
|
The following section documents the endpoints we implement and their respective paths. This information can traditionally
|
||||||
particularly those that don't use [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html). The paths are
|
be discovered by relying parties that utilize [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html),
|
||||||
appended to the end of the primary URL used to access Authelia. For example in the Discovery example provided you access
|
however this information may be useful for clients which do not implement this.
|
||||||
Authelia via https://auth.example.com, the discovery URL is https://auth.example.com/.well-known/openid-configuration.
|
|
||||||
|
The endpoints can be discovered easily by visiting the Discovery and Metadata endpoints. It is recommended regardless
|
||||||
|
of your version of Authelia that you utilize this version as it will always produce the correct endpoint URLs. The paths
|
||||||
|
for the Discovery/Metadata endpoints are part of IANA's well known registration but are also documented in a table below.
|
||||||
|
|
||||||
|
These tables document the endpoints we currently support and their paths in the most recent version of Authelia. The paths
|
||||||
|
are appended to the end of the primary URL used to access Authelia. The tables use the url https://auth.example.com as
|
||||||
|
an example of the Authelia root URL which is also the OpenID Connect issuer.
|
||||||
|
|
||||||
|
### Well Known Discovery Endpoints
|
||||||
|
|
||||||
|
These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP.
|
||||||
|
|
||||||
| Endpoint | Path |
|
| Endpoint | Path |
|
||||||
|:-------------:|:---------------------------------------------:|
|
|:-------------:|:---------------------------------------------------------------:|
|
||||||
| Discovery | [root]/.well-known/openid-configuration |
|
| Discovery | https://auth.example.com/.well-known/openid-configuration |
|
||||||
| Metadata | [root]/.well-known/oauth-authorization-server |
|
| Metadata | https://auth.example.com/.well-known/oauth-authorization-server |
|
||||||
| JWKS | [root]/api/oidc/jwks |
|
|
||||||
| Authorization | [root]/api/oidc/authorization |
|
|
||||||
| Token | [root]/api/oidc/token |
|
### Discoverable Endpoints
|
||||||
| Introspection | [root]/api/oidc/introspection |
|
|
||||||
| Revocation | [root]/api/oidc/revocation |
|
These endpoints implement OpenID Connect elements.
|
||||||
| Userinfo | [root]/api/oidc/userinfo |
|
|
||||||
|
| Endpoint | Path | Discovery Attribute |
|
||||||
|
|:---------------:|:-----------------------------------------------:|:----------------------:|
|
||||||
|
| JWKS | https://auth.example.com/jwks.json | jwks_uri |
|
||||||
|
| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
|
||||||
|
| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
|
||||||
|
| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
|
||||||
|
| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
|
||||||
|
| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
|
||||||
|
|
||||||
[OpenID Connect]: https://openid.net/connect/
|
[OpenID Connect]: https://openid.net/connect/
|
||||||
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
|
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
|
||||||
|
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
|
||||||
|
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||||
|
[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||||
|
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
|
||||||
|
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
|
||||||
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
|
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
|
|
@ -767,6 +767,26 @@ notifier:
|
||||||
## for security reasons.
|
## for security reasons.
|
||||||
# enforce_pkce: public_clients_only
|
# enforce_pkce: public_clients_only
|
||||||
|
|
||||||
|
## Cross-Origin Resource Sharing (CORS) settings.
|
||||||
|
# cors:
|
||||||
|
## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
|
||||||
|
# endpoints:
|
||||||
|
# - authorization
|
||||||
|
# - token
|
||||||
|
# - revocation
|
||||||
|
# - introspection
|
||||||
|
# - userinfo
|
||||||
|
|
||||||
|
## List of allowed origins.
|
||||||
|
## Any origin with https is permitted unless this option is configured or the
|
||||||
|
## allowed_origins_from_client_redirect_uris option is enabled.
|
||||||
|
# allowed_origins:
|
||||||
|
# - https://example.com
|
||||||
|
|
||||||
|
## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
|
||||||
|
## provided they have the scheme http or https and do not have the hostname of localhost.
|
||||||
|
# allowed_origins_from_client_redirect_uris: false
|
||||||
|
|
||||||
## Clients is a list of known clients and their configuration.
|
## Clients is a list of known clients and their configuration.
|
||||||
# clients:
|
# clients:
|
||||||
# -
|
# -
|
||||||
|
|
|
@ -201,6 +201,24 @@ func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) {
|
||||||
assert.Equal(t, "example_secret value", config.Storage.EncryptionKey)
|
assert.Equal(t, "example_secret value", config.Storage.EncryptionKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldLoadURLList(t *testing.T) {
|
||||||
|
testReset()
|
||||||
|
|
||||||
|
val := schema.NewStructValidator()
|
||||||
|
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
validator.ValidateKeys(keys, DefaultEnvPrefix, val)
|
||||||
|
|
||||||
|
assert.Len(t, val.Errors(), 0)
|
||||||
|
assert.Len(t, val.Warnings(), 0)
|
||||||
|
|
||||||
|
require.Len(t, config.IdentityProviders.OIDC.CORS.AllowedOrigins, 2)
|
||||||
|
assert.Equal(t, "https://google.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[0].String())
|
||||||
|
assert.Equal(t, "https://example.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[1].String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
|
func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
|
||||||
testReset()
|
testReset()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
package schema
|
package schema
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia.
|
// IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia.
|
||||||
type IdentityProvidersConfiguration struct {
|
type IdentityProvidersConfiguration struct {
|
||||||
|
@ -24,9 +27,19 @@ type OpenIDConnectConfiguration struct {
|
||||||
EnforcePKCE string `koanf:"enforce_pkce"`
|
EnforcePKCE string `koanf:"enforce_pkce"`
|
||||||
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
|
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
|
||||||
|
|
||||||
|
CORS OpenIDConnectCORSConfiguration `koanf:"cors"`
|
||||||
|
|
||||||
Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
|
Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config.
|
||||||
|
type OpenIDConnectCORSConfiguration struct {
|
||||||
|
Endpoints []string `koanf:"endpoints"`
|
||||||
|
AllowedOrigins []url.URL `koanf:"allowed_origins"`
|
||||||
|
|
||||||
|
AllowedOriginsFromClientRedirectURIs bool `koanf:"allowed_origins_from_client_redirect_uris"`
|
||||||
|
}
|
||||||
|
|
||||||
// OpenIDConnectClientConfiguration configuration for an OpenID Connect client.
|
// OpenIDConnectClientConfiguration configuration for an OpenID Connect client.
|
||||||
type OpenIDConnectClientConfiguration struct {
|
type OpenIDConnectClientConfiguration struct {
|
||||||
ID string `koanf:"id"`
|
ID string `koanf:"id"`
|
||||||
|
|
|
@ -0,0 +1,133 @@
|
||||||
|
---
|
||||||
|
default_redirection_url: https://home.example.com:8080/
|
||||||
|
|
||||||
|
server:
|
||||||
|
host: 127.0.0.1
|
||||||
|
port: 9091
|
||||||
|
|
||||||
|
log:
|
||||||
|
level: debug
|
||||||
|
|
||||||
|
totp:
|
||||||
|
issuer: authelia.com
|
||||||
|
|
||||||
|
duo_api:
|
||||||
|
hostname: api-123456789.example.com
|
||||||
|
integration_key: ABCDEF
|
||||||
|
|
||||||
|
authentication_backend:
|
||||||
|
ldap:
|
||||||
|
url: ldap://127.0.0.1
|
||||||
|
base_dn: dc=example,dc=com
|
||||||
|
username_attribute: uid
|
||||||
|
additional_users_dn: ou=users
|
||||||
|
users_filter: (&({username_attribute}={input})(objectCategory=person)(objectClass=user))
|
||||||
|
additional_groups_dn: ou=groups
|
||||||
|
groups_filter: (&(member={dn})(objectClass=groupOfNames))
|
||||||
|
group_name_attribute: cn
|
||||||
|
mail_attribute: mail
|
||||||
|
user: cn=admin,dc=example,dc=com
|
||||||
|
|
||||||
|
access_control:
|
||||||
|
default_policy: deny
|
||||||
|
|
||||||
|
rules:
|
||||||
|
# Rules applied to everyone
|
||||||
|
- domain: public.example.com
|
||||||
|
policy: bypass
|
||||||
|
|
||||||
|
- domain: secure.example.com
|
||||||
|
policy: one_factor
|
||||||
|
# Network based rule, if not provided any network matches.
|
||||||
|
networks:
|
||||||
|
- 192.168.1.0/24
|
||||||
|
- domain: secure.example.com
|
||||||
|
policy: two_factor
|
||||||
|
|
||||||
|
- domain: [singlefactor.example.com, onefactor.example.com]
|
||||||
|
policy: one_factor
|
||||||
|
|
||||||
|
# Rules applied to 'admins' group
|
||||||
|
- domain: "mx2.mail.example.com"
|
||||||
|
subject: "group:admins"
|
||||||
|
policy: deny
|
||||||
|
- domain: "*.example.com"
|
||||||
|
subject: "group:admins"
|
||||||
|
policy: two_factor
|
||||||
|
|
||||||
|
# Rules applied to 'dev' group
|
||||||
|
- domain: dev.example.com
|
||||||
|
resources:
|
||||||
|
- "^/groups/dev/.*$"
|
||||||
|
subject: "group:dev"
|
||||||
|
policy: two_factor
|
||||||
|
|
||||||
|
# Rules applied to user 'john'
|
||||||
|
- domain: dev.example.com
|
||||||
|
resources:
|
||||||
|
- "^/users/john/.*$"
|
||||||
|
subject: "user:john"
|
||||||
|
policy: two_factor
|
||||||
|
|
||||||
|
# Rules applied to 'dev' group and user 'john'
|
||||||
|
- domain: dev.example.com
|
||||||
|
resources:
|
||||||
|
- "^/deny-all.*$"
|
||||||
|
subject: ["group:dev", "user:john"]
|
||||||
|
policy: deny
|
||||||
|
|
||||||
|
# Rules applied to user 'harry'
|
||||||
|
- domain: dev.example.com
|
||||||
|
resources:
|
||||||
|
- "^/users/harry/.*$"
|
||||||
|
subject: "user:harry"
|
||||||
|
policy: two_factor
|
||||||
|
|
||||||
|
# Rules applied to user 'bob'
|
||||||
|
- domain: "*.mail.example.com"
|
||||||
|
subject: "user:bob"
|
||||||
|
policy: two_factor
|
||||||
|
- domain: "dev.example.com"
|
||||||
|
resources:
|
||||||
|
- "^/users/bob/.*$"
|
||||||
|
subject: "user:bob"
|
||||||
|
policy: two_factor
|
||||||
|
|
||||||
|
session:
|
||||||
|
name: authelia_session
|
||||||
|
expiration: 3600000 # 1 hour
|
||||||
|
inactivity: 300000 # 5 minutes
|
||||||
|
domain: example.com
|
||||||
|
redis:
|
||||||
|
host: 127.0.0.1
|
||||||
|
port: 6379
|
||||||
|
high_availability:
|
||||||
|
sentinel_name: test
|
||||||
|
|
||||||
|
regulation:
|
||||||
|
max_retries: 3
|
||||||
|
find_time: 120
|
||||||
|
ban_time: 300
|
||||||
|
|
||||||
|
storage:
|
||||||
|
mysql:
|
||||||
|
host: 127.0.0.1
|
||||||
|
port: 3306
|
||||||
|
database: authelia
|
||||||
|
username: authelia
|
||||||
|
|
||||||
|
notifier:
|
||||||
|
smtp:
|
||||||
|
username: test
|
||||||
|
host: 127.0.0.1
|
||||||
|
port: 1025
|
||||||
|
sender: admin@example.com
|
||||||
|
disable_require_tls: true
|
||||||
|
|
||||||
|
identity_providers:
|
||||||
|
oidc:
|
||||||
|
cors:
|
||||||
|
allowed_origins:
|
||||||
|
- https://google.com
|
||||||
|
- https://example.com
|
||||||
|
...
|
|
@ -124,10 +124,14 @@ const (
|
||||||
errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " +
|
errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " +
|
||||||
"more clients configured"
|
"more clients configured"
|
||||||
errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
|
errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
|
||||||
|
|
||||||
errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " +
|
errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " +
|
||||||
"'public_clients_only' or 'always', but it is configured as '%s'"
|
"'public_clients_only' or 'always', but it is configured as '%s'"
|
||||||
|
|
||||||
|
errFmtOIDCCORSInvalidOrigin = "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value '%s' as it has a %s: origins must only be scheme, hostname, and an optional port"
|
||||||
|
errFmtOIDCCORSInvalidOriginWildcard = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself"
|
||||||
|
errFmtOIDCCORSInvalidOriginWildcardWithClients = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled"
|
||||||
|
errFmtOIDCCORSInvalidEndpoint = "identity_providers: oidc: cors: option 'endpoints' contains an invalid value '%s': must be one of '%s'"
|
||||||
|
|
||||||
errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" +
|
errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" +
|
||||||
"id's must be unique"
|
"id's must be unique"
|
||||||
errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " +
|
errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " +
|
||||||
|
@ -275,6 +279,7 @@ var validOIDCScopes = []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProf
|
||||||
var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}
|
var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}
|
||||||
var validOIDCResponseModes = []string{"form_post", "query", "fragment"}
|
var validOIDCResponseModes = []string{"form_post", "query", "fragment"}
|
||||||
var validOIDCUserinfoAlgorithms = []string{"none", "RS256"}
|
var validOIDCUserinfoAlgorithms = []string{"none", "RS256"}
|
||||||
|
var validOIDCCORSEndpoints = []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint}
|
||||||
|
|
||||||
var reKeyReplacer = regexp.MustCompile(`\[\d+]`)
|
var reKeyReplacer = regexp.MustCompile(`\[\d+]`)
|
||||||
|
|
||||||
|
@ -471,6 +476,9 @@ var ValidKeys = []string{
|
||||||
"identity_providers.oidc.enable_pkce_plain_challenge",
|
"identity_providers.oidc.enable_pkce_plain_challenge",
|
||||||
"identity_providers.oidc.enable_client_debug_messages",
|
"identity_providers.oidc.enable_client_debug_messages",
|
||||||
"identity_providers.oidc.minimum_parameter_entropy",
|
"identity_providers.oidc.minimum_parameter_entropy",
|
||||||
|
"identity_providers.oidc.cors.endpoints",
|
||||||
|
"identity_providers.oidc.cors.allowed_origins",
|
||||||
|
"identity_providers.oidc.cors.enable_origins_from_clients",
|
||||||
"identity_providers.oidc.clients",
|
"identity_providers.oidc.clients",
|
||||||
"identity_providers.oidc.clients[].id",
|
"identity_providers.oidc.clients[].id",
|
||||||
"identity_providers.oidc.clients[].description",
|
"identity_providers.oidc.clients[].description",
|
||||||
|
|
|
@ -49,6 +49,7 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
|
||||||
validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE))
|
validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validateOIDCOptionsCORS(config, validator)
|
||||||
validateOIDCClients(config, validator)
|
validateOIDCClients(config, validator)
|
||||||
|
|
||||||
if len(config.Clients) == 0 {
|
if len(config.Clients) == 0 {
|
||||||
|
@ -57,6 +58,64 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateOIDCOptionsCORS(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||||
|
validateOIDCOptionsCORSAllowedOrigins(config, validator)
|
||||||
|
|
||||||
|
if config.CORS.AllowedOriginsFromClientRedirectURIs {
|
||||||
|
validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
validateOIDCOptionsCORSEndpoints(config, validator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||||
|
for _, origin := range config.CORS.AllowedOrigins {
|
||||||
|
if origin.String() == "*" {
|
||||||
|
if len(config.CORS.AllowedOrigins) != 1 {
|
||||||
|
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard))
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.CORS.AllowedOriginsFromClientRedirectURIs {
|
||||||
|
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients))
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if origin.Path != "" {
|
||||||
|
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if origin.RawQuery != "" {
|
||||||
|
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
|
||||||
|
for _, client := range config.Clients {
|
||||||
|
for _, redirectURI := range client.RedirectURIs {
|
||||||
|
uri, err := url.Parse(redirectURI)
|
||||||
|
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
origin := utils.OriginFromURL(*uri)
|
||||||
|
|
||||||
|
if !utils.IsURLInSlice(origin, config.CORS.AllowedOrigins) {
|
||||||
|
config.CORS.AllowedOrigins = append(config.CORS.AllowedOrigins, origin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||||
|
for _, endpoint := range config.CORS.Endpoints {
|
||||||
|
if !utils.IsStringInSlice(endpoint, validOIDCCORSEndpoints) {
|
||||||
|
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||||
invalidID, duplicateIDs := false, false
|
invalidID, duplicateIDs := false, false
|
||||||
|
|
||||||
|
@ -97,7 +156,6 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s
|
||||||
validateOIDCClientResponseTypes(c, config, validator)
|
validateOIDCClientResponseTypes(c, config, validator)
|
||||||
validateOIDCClientResponseModes(c, config, validator)
|
validateOIDCClientResponseModes(c, config, validator)
|
||||||
validateOIDDClientUserinfoAlgorithm(c, config, validator)
|
validateOIDDClientUserinfoAlgorithm(c, config, validator)
|
||||||
|
|
||||||
validateOIDCClientRedirectURIs(client, validator)
|
validateOIDCClientRedirectURIs(client, validator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
|
"github.com/authelia/authelia/v4/internal/oidc"
|
||||||
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
|
func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
|
||||||
|
@ -29,6 +31,54 @@ func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
|
||||||
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
|
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldNotRaiseErrorWhenCORSEndpointsValid(t *testing.T) {
|
||||||
|
validator := schema.NewStructValidator()
|
||||||
|
config := &schema.IdentityProvidersConfiguration{
|
||||||
|
OIDC: &schema.OpenIDConnectConfiguration{
|
||||||
|
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
|
||||||
|
IssuerPrivateKey: "key-material",
|
||||||
|
CORS: schema.OpenIDConnectCORSConfiguration{
|
||||||
|
Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint},
|
||||||
|
},
|
||||||
|
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||||
|
{
|
||||||
|
ID: "example",
|
||||||
|
Secret: "example",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ValidateIdentityProviders(config, validator)
|
||||||
|
|
||||||
|
assert.Len(t, validator.Errors(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) {
|
||||||
|
validator := schema.NewStructValidator()
|
||||||
|
config := &schema.IdentityProvidersConfiguration{
|
||||||
|
OIDC: &schema.OpenIDConnectConfiguration{
|
||||||
|
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
|
||||||
|
IssuerPrivateKey: "key-material",
|
||||||
|
CORS: schema.OpenIDConnectCORSConfiguration{
|
||||||
|
Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint, "invalid_endpoint"},
|
||||||
|
},
|
||||||
|
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||||
|
{
|
||||||
|
ID: "example",
|
||||||
|
Secret: "example",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ValidateIdentityProviders(config, validator)
|
||||||
|
|
||||||
|
require.Len(t, validator.Errors(), 1)
|
||||||
|
|
||||||
|
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'")
|
||||||
|
}
|
||||||
|
|
||||||
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
|
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
|
||||||
validator := schema.NewStructValidator()
|
validator := schema.NewStructValidator()
|
||||||
config := &schema.IdentityProvidersConfiguration{
|
config := &schema.IdentityProvidersConfiguration{
|
||||||
|
@ -47,7 +97,44 @@ func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
|
||||||
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
|
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldRaiseErrorWhenOIDCServerIssuerPrivateKeyPathInvalid(t *testing.T) {
|
func TestShouldRaiseErrorWhenOIDCCORSOriginsHasInvalidValues(t *testing.T) {
|
||||||
|
validator := schema.NewStructValidator()
|
||||||
|
|
||||||
|
config := &schema.IdentityProvidersConfiguration{
|
||||||
|
OIDC: &schema.OpenIDConnectConfiguration{
|
||||||
|
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
|
||||||
|
IssuerPrivateKey: "key-material",
|
||||||
|
CORS: schema.OpenIDConnectCORSConfiguration{
|
||||||
|
AllowedOrigins: utils.URLsFromStringSlice([]string{"https://example.com/", "https://site.example.com/subpath", "https://site.example.com?example=true", "*"}),
|
||||||
|
AllowedOriginsFromClientRedirectURIs: true,
|
||||||
|
},
|
||||||
|
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||||
|
{
|
||||||
|
ID: "myclient",
|
||||||
|
Secret: "jk12nb3klqwmnelqkwenm",
|
||||||
|
Policy: "two_factor",
|
||||||
|
RedirectURIs: []string{"https://example.com/oauth2_callback", "https://localhost:566/callback", "http://an.example.com/callback", "file://a/file"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ValidateIdentityProviders(config, validator)
|
||||||
|
|
||||||
|
require.Len(t, validator.Errors(), 6)
|
||||||
|
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://example.com/' as it has a path: origins must only be scheme, hostname, and an optional port")
|
||||||
|
assert.EqualError(t, validator.Errors()[1], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com/subpath' as it has a path: origins must only be scheme, hostname, and an optional port")
|
||||||
|
assert.EqualError(t, validator.Errors()[2], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com?example=true' as it has a query string: origins must only be scheme, hostname, and an optional port")
|
||||||
|
assert.EqualError(t, validator.Errors()[3], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself")
|
||||||
|
assert.EqualError(t, validator.Errors()[4], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled")
|
||||||
|
assert.EqualError(t, validator.Errors()[5], "identity_providers: oidc: client 'myclient': option 'redirect_uris' has an invalid value: redirect uri 'file://a/file' must have a scheme of 'http' or 'https' but 'file' is configured")
|
||||||
|
|
||||||
|
require.Len(t, config.OIDC.CORS.AllowedOrigins, 6)
|
||||||
|
assert.Equal(t, "*", config.OIDC.CORS.AllowedOrigins[3].String())
|
||||||
|
assert.Equal(t, "https://example.com", config.OIDC.CORS.AllowedOrigins[4].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldRaiseErrorWhenOIDCServerNoClients(t *testing.T) {
|
||||||
validator := schema.NewStructValidator()
|
validator := schema.NewStructValidator()
|
||||||
config := &schema.IdentityProvidersConfiguration{
|
config := &schema.IdentityProvidersConfiguration{
|
||||||
OIDC: &schema.OpenIDConnectConfiguration{
|
OIDC: &schema.OpenIDConnectConfiguration{
|
||||||
|
|
|
@ -72,16 +72,6 @@ const (
|
||||||
auth = "auth"
|
auth = "auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OIDC constants.
|
|
||||||
const (
|
|
||||||
pathLegacyOpenIDConnectAuthorization = "/api/oidc/authorize"
|
|
||||||
pathLegacyOpenIDConnectIntrospection = "/api/oidc/introspect"
|
|
||||||
pathLegacyOpenIDConnectRevocation = "/api/oidc/revoke"
|
|
||||||
|
|
||||||
// Note: If you change this const you must also do so in the frontend at web/src/services/Api.ts.
|
|
||||||
pathOpenIDConnectConsent = "/api/oidc/consent"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
accept = "accept"
|
accept = "accept"
|
||||||
reject = "reject"
|
reject = "reject"
|
||||||
|
|
|
@ -6,7 +6,8 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcJWKs(ctx *middlewares.AutheliaCtx) {
|
// JSONWebKeySetGET returns the JSON Web Key Set. Used in OAuth 2.0 and OpenID Connect 1.0.
|
||||||
|
func JSONWebKeySetGET(ctx *middlewares.AutheliaCtx) {
|
||||||
ctx.SetContentType("application/json")
|
ctx.SetContentType("application/json")
|
||||||
|
|
||||||
if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil {
|
if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil {
|
|
@ -9,7 +9,10 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/oidc"
|
"github.com/authelia/authelia/v4/internal/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcIntrospection(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
// OAuthIntrospectionPOST handles POST requests to the OAuth 2.0 Introspection endpoint.
|
||||||
|
//
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc7662
|
||||||
|
func OAuthIntrospectionPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||||
var (
|
var (
|
||||||
responder fosite.IntrospectionResponder
|
responder fosite.IntrospectionResponder
|
||||||
err error
|
err error
|
|
@ -8,7 +8,10 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcRevocation(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
// OAuthRevocationPOST handles POST requests to the OAuth 2.0 Revocation endpoint.
|
||||||
|
//
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc7009
|
||||||
|
func OAuthRevocationPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil {
|
if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil {
|
|
@ -16,7 +16,10 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/session"
|
"github.com/authelia/authelia/v4/internal/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
|
// OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint.
|
||||||
|
//
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
|
||||||
|
func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
|
||||||
var (
|
var (
|
||||||
requester fosite.AuthorizeRequester
|
requester fosite.AuthorizeRequester
|
||||||
responder fosite.AuthorizeResponder
|
responder fosite.AuthorizeResponder
|
||||||
|
|
|
@ -7,7 +7,8 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcConsent(ctx *middlewares.AutheliaCtx) {
|
// OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect.
|
||||||
|
func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) {
|
||||||
userSession := ctx.GetSession()
|
userSession := ctx.GetSession()
|
||||||
|
|
||||||
if userSession.OIDCWorkflowSession == nil {
|
if userSession.OIDCWorkflowSession == nil {
|
||||||
|
@ -39,7 +40,8 @@ func oidcConsent(ctx *middlewares.AutheliaCtx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func oidcConsentPOST(ctx *middlewares.AutheliaCtx) {
|
// OpenIDConnectConsentPOST handles consent responses for OpenID Connect.
|
||||||
|
func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
|
||||||
userSession := ctx.GetSession()
|
userSession := ctx.GetSession()
|
||||||
|
|
||||||
if userSession.OIDCWorkflowSession == nil {
|
if userSession.OIDCWorkflowSession == nil {
|
||||||
|
|
|
@ -9,7 +9,10 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/oidc"
|
"github.com/authelia/authelia/v4/internal/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcToken(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
// OpenIDConnectTokenPOST handles POST requests to the OpenID Connect 1.0 Token endpoint.
|
||||||
|
//
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||||
|
func OpenIDConnectTokenPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||||
var (
|
var (
|
||||||
requester fosite.AccessRequester
|
requester fosite.AccessRequester
|
||||||
responder fosite.AccessResponder
|
responder fosite.AccessResponder
|
||||||
|
|
|
@ -14,7 +14,10 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/oidc"
|
"github.com/authelia/authelia/v4/internal/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
// OpenIDConnectUserinfo handles GET/POST requests to the OpenID Connect 1.0 UserInfo endpoint.
|
||||||
|
//
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||||
|
func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||||
var (
|
var (
|
||||||
tokenType fosite.TokenType
|
tokenType fosite.TokenType
|
||||||
requester fosite.AccessRequester
|
requester fosite.AccessRequester
|
||||||
|
@ -97,7 +100,7 @@ func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *htt
|
||||||
var jti uuid.UUID
|
var jti uuid.UUID
|
||||||
|
|
||||||
if jti, err = uuid.NewRandom(); err != nil {
|
if jti, err = uuid.NewRandom(); err != nil {
|
||||||
ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JWT ID."))
|
ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JTI."))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,13 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
)
|
)
|
||||||
|
|
||||||
func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
|
// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
||||||
|
// OpenID Connect Discovery 1.0 metadata.
|
||||||
|
//
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc5785
|
||||||
|
//
|
||||||
|
// https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||||
|
func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||||
issuer, err := ctx.ExternalRootURL()
|
issuer, err := ctx.ExternalRootURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
|
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
|
||||||
|
@ -30,7 +36,13 @@ func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func wellKnownOAuthAuthorizationServerGET(ctx *middlewares.AutheliaCtx) {
|
// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
||||||
|
// OAuth 2.0 Authorization Server Metadata (RFC8414).
|
||||||
|
//
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc5785
|
||||||
|
//
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc8414
|
||||||
|
func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||||
issuer, err := ctx.ExternalRootURL()
|
issuer, err := ctx.ExternalRootURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
|
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/fasthttp/router"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
|
||||||
"github.com/authelia/authelia/v4/internal/oidc"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout.
|
|
||||||
func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) {
|
|
||||||
// TODO: Add OPTIONS handler.
|
|
||||||
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET)))
|
|
||||||
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET)))
|
|
||||||
|
|
||||||
router.GET(pathOpenIDConnectConsent, middleware(oidcConsent))
|
|
||||||
|
|
||||||
router.POST(pathOpenIDConnectConsent, middleware(oidcConsentPOST))
|
|
||||||
|
|
||||||
router.GET(oidc.JWKsPath, middleware(oidcJWKs))
|
|
||||||
|
|
||||||
router.GET(oidc.AuthorizationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
|
|
||||||
router.GET(pathLegacyOpenIDConnectAuthorization, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
|
|
||||||
|
|
||||||
// TODO: Add OPTIONS handler.
|
|
||||||
router.POST(oidc.TokenPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcToken)))
|
|
||||||
|
|
||||||
router.POST(oidc.IntrospectionPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
|
|
||||||
router.GET(pathLegacyOpenIDConnectIntrospection, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
|
|
||||||
|
|
||||||
router.GET(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
|
|
||||||
router.POST(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
|
|
||||||
|
|
||||||
// TODO: Add OPTIONS handler.
|
|
||||||
router.POST(oidc.RevocationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
|
|
||||||
router.POST(pathLegacyOpenIDConnectRevocation, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
|
|
||||||
}
|
|
|
@ -7,18 +7,22 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
headerAccept = []byte(fasthttp.HeaderAccept)
|
||||||
|
headerContentLength = []byte(fasthttp.HeaderContentLength)
|
||||||
|
|
||||||
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
|
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
|
||||||
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
|
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
|
||||||
headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor)
|
headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor)
|
||||||
headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith)
|
headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith)
|
||||||
headerAccept = []byte(fasthttp.HeaderAccept)
|
|
||||||
|
|
||||||
headerXForwardedURI = []byte("X-Forwarded-URI")
|
headerXForwardedURI = []byte("X-Forwarded-URI")
|
||||||
headerXOriginalURL = []byte("X-Original-URL")
|
headerXOriginalURL = []byte("X-Original-URL")
|
||||||
headerXForwardedMethod = []byte("X-Forwarded-Method")
|
headerXForwardedMethod = []byte("X-Forwarded-Method")
|
||||||
|
|
||||||
headerVary = []byte(fasthttp.HeaderVary)
|
headerVary = []byte(fasthttp.HeaderVary)
|
||||||
|
headerAllow = []byte(fasthttp.HeaderAllow)
|
||||||
headerOrigin = []byte(fasthttp.HeaderOrigin)
|
headerOrigin = []byte(fasthttp.HeaderOrigin)
|
||||||
|
|
||||||
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
|
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
|
||||||
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
|
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
|
||||||
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
|
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
|
||||||
|
@ -30,8 +34,12 @@ var (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
headerValueFalse = []byte("false")
|
headerValueFalse = []byte("false")
|
||||||
|
headerValueTrue = []byte("true")
|
||||||
headerValueMaxAge = []byte("100")
|
headerValueMaxAge = []byte("100")
|
||||||
headerValueVary = []byte("Accept-Encoding, Origin")
|
headerValueVary = []byte("Accept-Encoding, Origin")
|
||||||
|
headerValueVaryWildcard = []byte("Accept-Encoding")
|
||||||
|
headerValueOriginWildcard = []byte("*")
|
||||||
|
headerValueZero = []byte("0")
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -40,6 +48,8 @@ var (
|
||||||
|
|
||||||
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
||||||
UserValueKeyBaseURL = []byte("base_url")
|
UserValueKeyBaseURL = []byte("base_url")
|
||||||
|
|
||||||
|
headerSeparator = []byte(", ")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -1,53 +1,347 @@
|
||||||
package middlewares
|
package middlewares
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well
|
// NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy which adds the Vary header
|
||||||
// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied
|
// with a value reflecting that the Origin header will Vary this response, then if the Origin header has a https scheme
|
||||||
// to both Accept-Encoding and Origin. It grants the GET Request Method only.
|
// it makes the following additional adjustments: copies the Origin header to the Access-Control-Allow-Origin header
|
||||||
func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler {
|
// effectively allowing all origins, sets the Access-Control-Allow-Credentials header to false which disallows CORS
|
||||||
return func(ctx *AutheliaCtx) {
|
// requests from sending cookies etc, sets the Access-Control-Allow-Headers header to the value specified by
|
||||||
if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil {
|
// Access-Control-Request-Headers in the request excluding the Cookie/Authorization/Proxy-Authorization and special *
|
||||||
corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin)
|
// values, sets Access-Control-Allow-Methods to the value specified by the Access-Control-Request-Method header, sets
|
||||||
|
// the Access-Control-Max-Age header to 100.
|
||||||
|
//
|
||||||
|
// These behaviours can be overridden by the With methods on the returned policy.
|
||||||
|
func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) {
|
||||||
|
return &CORSPolicyBuilder{
|
||||||
|
enabled: true,
|
||||||
|
maxAge: 100,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSPolicyBuilder is a special middleware which provides CORS headers via handlers and middleware methods which can be
|
||||||
|
// configured. It aims to simplify CORS configurations.
|
||||||
|
type CORSPolicyBuilder struct {
|
||||||
|
enabled bool
|
||||||
|
varyOnly bool
|
||||||
|
varySet bool
|
||||||
|
methods []string
|
||||||
|
headers []string
|
||||||
|
origins []string
|
||||||
|
credentials bool
|
||||||
|
vary []string
|
||||||
|
maxAge int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build reads the CORSPolicyBuilder configuration and generates a CORSPolicy.
|
||||||
|
func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) {
|
||||||
|
policy = &CORSPolicy{
|
||||||
|
enabled: b.enabled,
|
||||||
|
varyOnly: b.varyOnly,
|
||||||
|
credentials: []byte(strconv.FormatBool(b.credentials)),
|
||||||
|
origins: b.buildOrigins(),
|
||||||
|
headers: b.buildHeaders(),
|
||||||
|
vary: b.buildVary(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b.methods) != 0 {
|
||||||
|
policy.methods = []byte(strings.Join(b.methods, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.maxAge <= 0 {
|
||||||
|
policy.maxAge = headerValueMaxAge
|
||||||
|
} else {
|
||||||
|
policy.maxAge = []byte(strconv.Itoa(b.maxAge))
|
||||||
|
}
|
||||||
|
|
||||||
|
return policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) {
|
||||||
|
if len(b.origins) != 0 {
|
||||||
|
if len(b.origins) == 1 && b.origins[0] == "*" {
|
||||||
|
origins = append(origins, []byte(b.origins[0]))
|
||||||
|
} else {
|
||||||
|
for _, origin := range b.origins {
|
||||||
|
origins = append(origins, []byte(origin))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return origins
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b CORSPolicyBuilder) buildHeaders() (headers []byte) {
|
||||||
|
if len(b.headers) != 0 {
|
||||||
|
h := b.headers
|
||||||
|
|
||||||
|
if b.credentials {
|
||||||
|
if !utils.IsStringInSliceFold(fasthttp.HeaderCookie, h) {
|
||||||
|
h = append(h, fasthttp.HeaderCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utils.IsStringInSliceFold(fasthttp.HeaderAuthorization, h) {
|
||||||
|
h = append(h, fasthttp.HeaderAuthorization)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utils.IsStringInSliceFold(fasthttp.HeaderProxyAuthorization, h) {
|
||||||
|
h = append(h, fasthttp.HeaderProxyAuthorization)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = utils.JoinAndCanonicalizeHeaders(headerSeparator, h...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b CORSPolicyBuilder) buildVary() (vary []byte) {
|
||||||
|
if b.varySet {
|
||||||
|
if len(b.vary) != 0 {
|
||||||
|
vary = utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(b.origins) == 1 && b.origins[0] == "*" {
|
||||||
|
vary = headerValueVaryWildcard
|
||||||
|
} else {
|
||||||
|
vary = headerValueVary
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return vary
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnabled changes the enabled state of the middleware. If the middleware is initialized with NewCORSPolicyBuilder this
|
||||||
|
// value will be true but this function can override the value. Setting it to false prevents the middleware from adding
|
||||||
|
// any CORS headers. The only effect this middleware has after disabling this is the HandleOPTIONS and HandleOnlyOPTIONS
|
||||||
|
// handlers still function to return a HTTP 204 No Content, with the Allow header communicating the available HTTP
|
||||||
|
// method verbs. The main benefit of this option is that you don't have to implement complex logic to add/remove the
|
||||||
|
// middleware, you can just add it with the Middleware method, and adjust it using the WithEnabled method.
|
||||||
|
func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) {
|
||||||
|
b.enabled = enabled
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAllowedMethods takes a list or HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with
|
||||||
|
// that value.
|
||||||
|
func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) {
|
||||||
|
b.methods = methods
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the origin matches one of these.
|
||||||
|
func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) {
|
||||||
|
b.origins = origins
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header.
|
||||||
|
func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) {
|
||||||
|
b.headers = headers
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAllowCredentials takes bool and alters the default Access-Control-Allow-Credentials header.
|
||||||
|
func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) {
|
||||||
|
b.credentials = allow
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithVary takes a list of header strings and alters the default Vary header.
|
||||||
|
func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) {
|
||||||
|
b.vary = headers
|
||||||
|
b.varySet = true
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithVaryOnly just adds the Vary header.
|
||||||
|
func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) {
|
||||||
|
b.varyOnly = varyOnly
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header.
|
||||||
|
func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) {
|
||||||
|
b.maxAge = age
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSPolicy is a middleware that handles adding CORS headers.
|
||||||
|
type CORSPolicy struct {
|
||||||
|
enabled bool
|
||||||
|
varyOnly bool
|
||||||
|
methods []byte
|
||||||
|
headers []byte
|
||||||
|
origins [][]byte
|
||||||
|
credentials []byte
|
||||||
|
vary []byte
|
||||||
|
maxAge []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 204
|
||||||
|
// without a body. This handler should generally not be used without using WithAllowedMethods.
|
||||||
|
func (p CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) {
|
||||||
|
p.handleOPTIONS(ctx)
|
||||||
|
p.handle(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 204
|
||||||
|
// without a body. This handler should generally not be used without using WithAllowedMethods.
|
||||||
|
func (p CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) {
|
||||||
|
p.handleOPTIONS(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicyBuilder.
|
||||||
|
func (p CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) {
|
||||||
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
|
p.handle(ctx)
|
||||||
|
|
||||||
next(ctx)
|
next(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) {
|
func (p CORSPolicy) handle(ctx *fasthttp.RequestCtx) {
|
||||||
originURL, err := url.Parse(string(origin))
|
if !p.enabled {
|
||||||
if err != nil || originURL.Scheme != "https" {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Header.SetBytesKV(headerVary, headerValueVary)
|
p.handleVary(ctx)
|
||||||
resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin)
|
|
||||||
resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse)
|
|
||||||
resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge)
|
|
||||||
|
|
||||||
if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
|
if !p.varyOnly {
|
||||||
|
p.handleCORS(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.Response.ResetBody()
|
||||||
|
|
||||||
|
/* The OPTIONS method should not return a 204 as per the following specifications when read together:
|
||||||
|
|
||||||
|
RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7):
|
||||||
|
A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in
|
||||||
|
the response.
|
||||||
|
|
||||||
|
RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2):
|
||||||
|
A server MUST NOT send a Content-Length header field in any response with a status code of 1xx (Informational)
|
||||||
|
or 204 (No Content).
|
||||||
|
*/
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
ctx.Response.Header.SetBytesKV(headerContentLength, headerValueZero)
|
||||||
|
|
||||||
|
if len(p.methods) != 0 {
|
||||||
|
ctx.Response.Header.SetBytesKV(headerAllow, p.methods)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) {
|
||||||
|
if len(p.vary) != 0 {
|
||||||
|
ctx.Response.Header.SetBytesKV(headerVary, p.vary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
|
||||||
|
var (
|
||||||
|
originURL *url.URL
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
origin := ctx.Request.Header.PeekBytes(headerOrigin)
|
||||||
|
|
||||||
|
// Skip processing of any `https` scheme URL that has not expressly been configured.
|
||||||
|
if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var allowedOrigin []byte
|
||||||
|
|
||||||
|
switch len(p.origins) {
|
||||||
|
case 0:
|
||||||
|
allowedOrigin = origin
|
||||||
|
default:
|
||||||
|
for i := 0; i < len(p.origins); i++ {
|
||||||
|
if bytes.Equal(p.origins[i], headerValueOriginWildcard) {
|
||||||
|
allowedOrigin = headerValueOriginWildcard
|
||||||
|
} else if bytes.Equal(p.origins[i], origin) {
|
||||||
|
allowedOrigin = origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allowedOrigin) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin)
|
||||||
|
|
||||||
|
if len(p.credentials) != 0 {
|
||||||
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.maxAge) != 0 {
|
||||||
|
ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.handleAllowedHeaders(ctx)
|
||||||
|
p.handleAllowedMethods(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) {
|
||||||
|
switch len(p.methods) {
|
||||||
|
case 0:
|
||||||
|
// TODO: It may be beneficial to be able to control this automatic behaviour.
|
||||||
|
if requestMethods := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
|
||||||
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) {
|
||||||
|
switch len(p.headers) {
|
||||||
|
case 0:
|
||||||
|
// TODO: It may be beneficial to be able to control this automatic behaviour.
|
||||||
|
if headers := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
|
||||||
requestedHeaders := strings.Split(string(headers), ",")
|
requestedHeaders := strings.Split(string(headers), ",")
|
||||||
allowHeaders := make([]string, len(requestedHeaders))
|
allowHeaders := make([]string, 0, len(requestedHeaders))
|
||||||
|
|
||||||
for i, header := range requestedHeaders {
|
for i := 0; i < len(requestedHeaders); i++ {
|
||||||
headerTrimmed := strings.Trim(header, " ")
|
headerTrimmed := strings.Trim(requestedHeaders[i], " ")
|
||||||
if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) {
|
|
||||||
allowHeaders[i] = headerTrimmed
|
if headerTrimmed == "*" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(p.credentials, headerValueTrue) ||
|
||||||
|
(!strings.EqualFold(fasthttp.HeaderCookie, headerTrimmed) &&
|
||||||
|
!strings.EqualFold(fasthttp.HeaderAuthorization, headerTrimmed) &&
|
||||||
|
!strings.EqualFold(fasthttp.HeaderProxyAuthorization, headerTrimmed)) {
|
||||||
|
allowHeaders = append(allowHeaders, headerTrimmed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(allowHeaders) != 0 {
|
if len(allowHeaders) != 0 {
|
||||||
resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
|
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers)
|
||||||
resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,61 +5,587 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
|
func TestNewCORSMiddleware(t *testing.T) {
|
||||||
req := fasthttp.AcquireRequest()
|
cors := NewCORSPolicyBuilder()
|
||||||
resp := fasthttp.Response{}
|
|
||||||
|
assert.Equal(t, 100, cors.maxAge)
|
||||||
|
assert.Equal(t, false, cors.credentials)
|
||||||
|
|
||||||
|
assert.Nil(t, cors.methods)
|
||||||
|
assert.Nil(t, cors.origins)
|
||||||
|
assert.Nil(t, cors.headers)
|
||||||
|
assert.Nil(t, cors.vary)
|
||||||
|
assert.False(t, cors.varyOnly)
|
||||||
|
assert.False(t, cors.varySet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithEnabled(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.True(t, cors.enabled)
|
||||||
|
|
||||||
|
cors.WithEnabled(false)
|
||||||
|
assert.False(t, cors.enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithVary(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.vary)
|
||||||
|
assert.False(t, cors.varyOnly)
|
||||||
|
assert.False(t, cors.varySet)
|
||||||
|
|
||||||
|
cors.WithVary()
|
||||||
|
assert.Nil(t, cors.vary)
|
||||||
|
assert.False(t, cors.varyOnly)
|
||||||
|
assert.True(t, cors.varySet)
|
||||||
|
|
||||||
|
cors.WithVary("Origin", "Example", "Test")
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"Origin", "Example", "Test"}, cors.vary)
|
||||||
|
assert.False(t, cors.varyOnly)
|
||||||
|
assert.True(t, cors.varySet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithAllowedMethods(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.methods)
|
||||||
|
|
||||||
|
cors.WithAllowedMethods("GET")
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"GET"}, cors.methods)
|
||||||
|
|
||||||
|
cors.WithAllowedMethods("POST", "PATCH")
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"POST", "PATCH"}, cors.methods)
|
||||||
|
|
||||||
|
cors.WithAllowedMethods()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.methods)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithAllowedOrigins(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.origins)
|
||||||
|
|
||||||
|
cors.WithAllowedOrigins("https://google.com", "http://localhost")
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"https://google.com", "http://localhost"}, cors.origins)
|
||||||
|
|
||||||
|
cors.WithAllowedOrigins()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.origins)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithAllowedHeaders(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.headers)
|
||||||
|
|
||||||
|
cors.WithAllowedHeaders("Example", "Another")
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"Example", "Another"}, cors.headers)
|
||||||
|
|
||||||
|
cors.WithAllowedHeaders()
|
||||||
|
|
||||||
|
assert.Nil(t, cors.headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithAllowCredentials(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.Equal(t, false, cors.credentials)
|
||||||
|
|
||||||
|
cors.WithAllowCredentials(false)
|
||||||
|
|
||||||
|
assert.Equal(t, false, cors.credentials)
|
||||||
|
|
||||||
|
cors.WithAllowCredentials(true)
|
||||||
|
|
||||||
|
assert.Equal(t, true, cors.credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithVaryOnly(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.False(t, cors.varyOnly)
|
||||||
|
|
||||||
|
cors.WithVaryOnly(false)
|
||||||
|
|
||||||
|
assert.False(t, cors.varyOnly)
|
||||||
|
|
||||||
|
cors.WithVaryOnly(true)
|
||||||
|
|
||||||
|
cors.WithVaryOnly(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithMaxAge(t *testing.T) {
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
assert.Equal(t, 100, cors.maxAge)
|
||||||
|
|
||||||
|
cors.WithMaxAge(20)
|
||||||
|
|
||||||
|
assert.Equal(t, 20, cors.maxAge)
|
||||||
|
|
||||||
|
cors.WithMaxAge(0)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, cors.maxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_HandleOPTIONS(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
origin := []byte("https://myapp.example.com")
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
cors := NewCORSPolicyBuilder()
|
||||||
|
policy := cors.Build()
|
||||||
|
|
||||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
|
policy.HandleOPTIONS(ctx)
|
||||||
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
|
||||||
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOnlyOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithEnabled(false)
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_HandleOPTIONS_WithoutOrigin(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
|
||||||
|
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedOrigins(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
cors.WithAllowedOrigins("https://myapp.example.com")
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithAllowedOrigins("https://anotherapp.example.com")
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithAllowedOrigins("*")
|
||||||
|
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_WithAllowedOrigins_DoesntOverrideVary(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
cors.WithVary("Accept-Encoding", "Origin", "Test")
|
||||||
|
cors.WithAllowedOrigins("*")
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin, Test"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_HandleOPTIONSWithVaryOnly(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
cors.WithVaryOnly(true)
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedHeaders(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
cors.WithAllowedHeaders("Example", "Test")
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
|
||||||
|
ctx = newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors.WithAllowCredentials(true)
|
||||||
|
|
||||||
|
policy = cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueTrue, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("Example, Test, Cookie, Authorization, Proxy-Authorization"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCORSPolicyBuilder_HandleOPTIONS_ShouldNotAllowWildcardInRequestedHeaders(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "*")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.HandleOPTIONS(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
policy.handle(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
|
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
|
||||||
req := fasthttp.AcquireRequest()
|
ctx := newFastHTTPRequestCtx()
|
||||||
resp := fasthttp.Response{}
|
|
||||||
|
|
||||||
origin := []byte("https://myapp.example.com")
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||||
|
|
||||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
cors := NewCORSPolicyBuilder()
|
||||||
|
|
||||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
|
policy := cors.Build()
|
||||||
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
policy.handle(ctx)
|
||||||
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
|
||||||
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte("GET"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
|
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
|
||||||
req := fasthttp.AcquireRequest()
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
resp := fasthttp.Response{}
|
|
||||||
|
|
||||||
origin := []byte("http://myapp.example.com")
|
origin := []byte("http://myapp.example.com")
|
||||||
|
|
||||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||||
|
|
||||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
cors := NewCORSPolicyBuilder().WithVary()
|
||||||
|
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary))
|
policy := cors.Build()
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
policy.handle(ctx)
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge))
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_CORSMiddleware_AsMiddleware(t *testing.T) {
|
||||||
|
ctx := newFastHTTPRequestCtx()
|
||||||
|
|
||||||
|
origin := []byte("https://myapp.example.com")
|
||||||
|
|
||||||
|
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||||
|
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||||
|
|
||||||
|
autheliaMiddleware := AutheliaMiddleware(schema.Configuration{}, Providers{})
|
||||||
|
|
||||||
|
cors := NewCORSPolicyBuilder().WithAllowedMethods("GET", "OPTIONS")
|
||||||
|
|
||||||
|
policy := cors.Build()
|
||||||
|
|
||||||
|
route := policy.Middleware(autheliaMiddleware(testNilHandler))
|
||||||
|
|
||||||
|
route(ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||||
|
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||||
|
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||||
|
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||||
|
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testNilHandler(_ *AutheliaCtx) {}
|
||||||
|
|
||||||
|
func newFastHTTPRequestCtx() (ctx *fasthttp.RequestCtx) {
|
||||||
|
return &fasthttp.RequestCtx{
|
||||||
|
Request: fasthttp.Request{},
|
||||||
|
Response: fasthttp.Response{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,17 +19,28 @@ const (
|
||||||
ClaimEmailAlts = "alt_emails"
|
ClaimEmailAlts = "alt_emails"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Endpoints.
|
||||||
|
const (
|
||||||
|
AuthorizationEndpoint = "authorization"
|
||||||
|
TokenEndpoint = "token"
|
||||||
|
UserinfoEndpoint = "userinfo"
|
||||||
|
IntrospectionEndpoint = "introspection"
|
||||||
|
RevocationEndpoint = "revocation"
|
||||||
|
)
|
||||||
|
|
||||||
// Paths.
|
// Paths.
|
||||||
const (
|
const (
|
||||||
WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration"
|
WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration"
|
||||||
WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server"
|
WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server"
|
||||||
|
JWKsPath = "/jwks.json"
|
||||||
|
|
||||||
JWKsPath = "/api/oidc/jwks"
|
RootPath = "/api/oidc"
|
||||||
AuthorizationPath = "/api/oidc/authorization"
|
|
||||||
TokenPath = "/api/oidc/token" //nolint:gosec // This is not a hard coded credential, it's a path.
|
AuthorizationPath = RootPath + "/" + AuthorizationEndpoint
|
||||||
IntrospectionPath = "/api/oidc/introspection"
|
TokenPath = RootPath + "/" + TokenEndpoint
|
||||||
RevocationPath = "/api/oidc/revocation"
|
UserinfoPath = RootPath + "/" + UserinfoEndpoint
|
||||||
UserinfoPath = "/api/oidc/userinfo"
|
IntrospectionPath = RootPath + "/" + IntrospectionEndpoint
|
||||||
|
RevocationPath = RootPath + "/" + RevocationEndpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176
|
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176
|
||||||
|
|
|
@ -87,7 +87,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
|
||||||
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
|
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
|
||||||
|
|
||||||
assert.Equal(t, "https://example.com", disco.Issuer)
|
assert.Equal(t, "https://example.com", disco.Issuer)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI)
|
assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
|
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
|
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint)
|
assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint)
|
||||||
|
@ -173,7 +173,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig
|
||||||
disco := provider.GetOAuth2WellKnownConfiguration("https://example.com")
|
disco := provider.GetOAuth2WellKnownConfiguration("https://example.com")
|
||||||
|
|
||||||
assert.Equal(t, "https://example.com", disco.Issuer)
|
assert.Equal(t, "https://example.com", disco.Issuer)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI)
|
assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
|
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
|
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
|
||||||
assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint)
|
assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint)
|
||||||
|
|
|
@ -37,16 +37,17 @@ var (
|
||||||
{name: "/api", prefix: "/api/"},
|
{name: "/api", prefix: "/api/"},
|
||||||
{name: "/.well-known", prefix: "/.well-known/"},
|
{name: "/.well-known", prefix: "/.well-known/"},
|
||||||
{name: "/static", prefix: "/static/"},
|
{name: "/static", prefix: "/static/"},
|
||||||
|
{name: "/locales", prefix: "/locales/"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const schemeHTTP = "http"
|
|
||||||
const schemeHTTPS = "https"
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dev = "dev"
|
dev = "dev"
|
||||||
f = "false"
|
f = "false"
|
||||||
t = "true"
|
t = "true"
|
||||||
|
localhost = "localhost"
|
||||||
|
schemeHTTP = "http"
|
||||||
|
schemeHTTPS = "https"
|
||||||
)
|
)
|
||||||
|
|
||||||
const healthCheckEnv = `# Written by Authelia Process
|
const healthCheckEnv = `# Written by Authelia Process
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Replacement for the default error handler in fasthttp.
|
|
||||||
func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) {
|
|
||||||
logger := logging.Logger()
|
|
||||||
|
|
||||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
|
||||||
// Note: Getting X-Forwarded-For or Request URI is impossible for ths error.
|
|
||||||
logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
|
|
||||||
ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
|
|
||||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
|
||||||
// TODO: Add X-Forwarded-For Check here.
|
|
||||||
logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
|
|
||||||
ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
|
|
||||||
} else {
|
|
||||||
// TODO: Add X-Forwarded-For Check here.
|
|
||||||
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
|
|
||||||
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/handlers"
|
|
||||||
)
|
|
||||||
|
|
||||||
func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
|
||||||
path := strings.ToLower(string(ctx.Path()))
|
|
||||||
|
|
||||||
for i := 0; i < len(httpServerDirs); i++ {
|
|
||||||
if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
|
|
||||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/handlers"
|
||||||
|
"github.com/authelia/authelia/v4/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Replacement for the default error handler in fasthttp.
|
||||||
|
func handlerErrors(ctx *fasthttp.RequestCtx, err error) {
|
||||||
|
logger := logging.Logger()
|
||||||
|
|
||||||
|
switch e := err.(type) {
|
||||||
|
case *fasthttp.ErrSmallBuffer:
|
||||||
|
logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
|
||||||
|
ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
|
||||||
|
case *net.OpError:
|
||||||
|
if e.Timeout() {
|
||||||
|
// TODO: Add X-Forwarded-For Check here.
|
||||||
|
logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
|
||||||
|
ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
|
||||||
|
} else {
|
||||||
|
// TODO: Add X-Forwarded-For Check here.
|
||||||
|
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
|
||||||
|
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// TODO: Add X-Forwarded-For Check here.
|
||||||
|
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
|
||||||
|
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
|
path := strings.ToLower(string(ctx.Path()))
|
||||||
|
|
||||||
|
for i := 0; i < len(httpServerDirs); i++ {
|
||||||
|
if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
|
||||||
|
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlerMethodNotAllowed(ctx *fasthttp.RequestCtx) {
|
||||||
|
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
|
||||||
|
}
|
|
@ -1,11 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
|
||||||
)
|
|
||||||
|
|
||||||
func handleOPTIONS(ctx *middlewares.AutheliaCtx) {
|
|
||||||
ctx.SetStatusCode(fasthttp.StatusNoContent)
|
|
||||||
}
|
|
|
@ -19,10 +19,12 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/handlers"
|
"github.com/authelia/authelia/v4/internal/handlers"
|
||||||
"github.com/authelia/authelia/v4/internal/logging"
|
"github.com/authelia/authelia/v4/internal/logging"
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
|
"github.com/authelia/authelia/v4/internal/oidc"
|
||||||
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: move to its own file and rename configuration -> config.
|
||||||
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||||
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
|
|
||||||
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled)
|
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled)
|
||||||
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
|
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
|
||||||
|
|
||||||
|
@ -33,37 +35,49 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
||||||
duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment)
|
duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment)
|
||||||
}
|
}
|
||||||
|
|
||||||
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
|
||||||
handlerLocales := newLocalesEmbeddedHandler()
|
|
||||||
|
|
||||||
https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != ""
|
https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != ""
|
||||||
|
|
||||||
serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
||||||
serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
||||||
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
||||||
|
|
||||||
|
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
||||||
|
handlerLocales := newLocalesEmbeddedHandler()
|
||||||
|
|
||||||
|
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
|
||||||
|
|
||||||
|
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
|
||||||
|
WithAllowedMethods("OPTIONS", "GET").
|
||||||
|
WithAllowedOrigins("*").
|
||||||
|
Build()
|
||||||
|
|
||||||
r := router.New()
|
r := router.New()
|
||||||
|
|
||||||
|
// Static Assets.
|
||||||
r.GET("/", autheliaMiddleware(serveIndexHandler))
|
r.GET("/", autheliaMiddleware(serveIndexHandler))
|
||||||
r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
|
|
||||||
|
|
||||||
for _, f := range rootFiles {
|
for _, f := range rootFiles {
|
||||||
r.GET("/"+f, handlerPublicHTML)
|
r.GET("/"+f, handlerPublicHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
|
|
||||||
r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler))
|
|
||||||
|
|
||||||
for _, file := range swaggerFiles {
|
|
||||||
r.GET("/api/"+file, handlerPublicHTML)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML))
|
r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML))
|
||||||
r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML))
|
r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML))
|
||||||
r.GET("/static/{filepath:*}", handlerPublicHTML)
|
r.GET("/static/{filepath:*}", handlerPublicHTML)
|
||||||
|
|
||||||
|
// Locales.
|
||||||
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
|
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
|
||||||
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
|
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
|
||||||
|
|
||||||
|
// Swagger.
|
||||||
|
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
|
||||||
|
r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
r.GET("/api/"+apiFile, policyCORSPublicGET.Middleware(autheliaMiddleware(serveSwaggerAPIHandler)))
|
||||||
|
r.OPTIONS("/api/"+apiFile, policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
|
||||||
|
for _, file := range swaggerFiles {
|
||||||
|
r.GET("/api/"+file, handlerPublicHTML)
|
||||||
|
}
|
||||||
|
|
||||||
r.GET("/api/health", autheliaMiddleware(handlers.HealthGet))
|
r.GET("/api/health", autheliaMiddleware(handlers.HealthGet))
|
||||||
r.GET("/api/state", autheliaMiddleware(handlers.StateGet))
|
r.GET("/api/state", autheliaMiddleware(handlers.StateGet))
|
||||||
|
|
||||||
|
@ -161,22 +175,98 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
||||||
r.GET("/debug/vars", expvarhandler.ExpvarHandler)
|
r.GET("/debug/vars", expvarhandler.ExpvarHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.NotFound = handleNotFound(autheliaMiddleware(serveIndexHandler))
|
if providers.OpenIDConnect.Fosite != nil {
|
||||||
|
r.GET("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentGET))
|
||||||
|
r.POST("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentPOST))
|
||||||
|
|
||||||
|
allowedOrigins := utils.StringSliceFromURLs(configuration.IdentityProviders.OIDC.CORS.AllowedOrigins)
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OpenIDConnectConfigurationWellKnownGET)))
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OAuthAuthorizationServerWellKnownGET)))
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.JWKsPath, policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
r.GET(oidc.JWKsPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
|
||||||
|
|
||||||
|
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
|
||||||
|
r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
|
||||||
|
|
||||||
|
policyCORSAuthorization := middlewares.NewCORSPolicyBuilder().
|
||||||
|
WithAllowedMethods("OPTIONS", "GET").
|
||||||
|
WithAllowedOrigins(allowedOrigins...).
|
||||||
|
WithEnabled(utils.IsStringInSlice(oidc.AuthorizationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.AuthorizationPath, policyCORSAuthorization.HandleOnlyOPTIONS)
|
||||||
|
r.GET(oidc.AuthorizationPath, autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
|
||||||
|
|
||||||
|
// TODO (james-d-elliott): Remove in GA. This is a legacy endpoint.
|
||||||
|
r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS)
|
||||||
|
r.GET("/api/oidc/authorize", autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
|
||||||
|
|
||||||
|
policyCORSToken := middlewares.NewCORSPolicyBuilder().
|
||||||
|
WithAllowCredentials(true).
|
||||||
|
WithAllowedMethods("OPTIONS", "POST").
|
||||||
|
WithAllowedOrigins(allowedOrigins...).
|
||||||
|
WithEnabled(utils.IsStringInSlice(oidc.TokenEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.TokenPath, policyCORSToken.HandleOPTIONS)
|
||||||
|
r.POST(oidc.TokenPath, policyCORSToken.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST))))
|
||||||
|
|
||||||
|
policyCORSUserinfo := middlewares.NewCORSPolicyBuilder().
|
||||||
|
WithAllowCredentials(true).
|
||||||
|
WithAllowedMethods("OPTIONS", "GET", "POST").
|
||||||
|
WithAllowedOrigins(allowedOrigins...).
|
||||||
|
WithEnabled(utils.IsStringInSlice(oidc.UserinfoEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.UserinfoPath, policyCORSUserinfo.HandleOPTIONS)
|
||||||
|
r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
|
||||||
|
r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
|
||||||
|
|
||||||
|
policyCORSIntrospection := middlewares.NewCORSPolicyBuilder().
|
||||||
|
WithAllowCredentials(true).
|
||||||
|
WithAllowedMethods("OPTIONS", "POST").
|
||||||
|
WithAllowedOrigins(allowedOrigins...).
|
||||||
|
WithEnabled(utils.IsStringInSlice(oidc.IntrospectionEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.IntrospectionPath, policyCORSIntrospection.HandleOPTIONS)
|
||||||
|
r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
|
||||||
|
|
||||||
|
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
|
||||||
|
r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS)
|
||||||
|
r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
|
||||||
|
|
||||||
|
policyCORSRevocation := middlewares.NewCORSPolicyBuilder().
|
||||||
|
WithAllowCredentials(true).
|
||||||
|
WithAllowedMethods("OPTIONS", "POST").
|
||||||
|
WithAllowedOrigins(allowedOrigins...).
|
||||||
|
WithEnabled(utils.IsStringInSlice(oidc.RevocationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
r.OPTIONS(oidc.RevocationPath, policyCORSRevocation.HandleOPTIONS)
|
||||||
|
r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
|
||||||
|
|
||||||
|
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
|
||||||
|
r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS)
|
||||||
|
r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.NotFound = handlerNotFound(autheliaMiddleware(serveIndexHandler))
|
||||||
|
|
||||||
r.HandleMethodNotAllowed = true
|
r.HandleMethodNotAllowed = true
|
||||||
r.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) {
|
r.MethodNotAllowed = handlerMethodNotAllowed
|
||||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := middlewares.LogRequestMiddleware(r.Handler)
|
handler := middlewares.LogRequestMiddleware(r.Handler)
|
||||||
if configuration.Server.Path != "" {
|
if configuration.Server.Path != "" {
|
||||||
handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler)
|
handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
if providers.OpenIDConnect.Fosite != nil {
|
|
||||||
handlers.RegisterOIDC(r, autheliaMiddleware)
|
|
||||||
}
|
|
||||||
|
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,12 +275,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
||||||
handler := registerRoutes(configuration, providers)
|
handler := registerRoutes(configuration, providers)
|
||||||
|
|
||||||
server := &fasthttp.Server{
|
server := &fasthttp.Server{
|
||||||
ErrorHandler: autheliaErrorHandler,
|
ErrorHandler: handlerErrors,
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
NoDefaultServerHeader: true,
|
NoDefaultServerHeader: true,
|
||||||
ReadBufferSize: configuration.Server.ReadBufferSize,
|
ReadBufferSize: configuration.Server.ReadBufferSize,
|
||||||
WriteBufferSize: configuration.Server.WriteBufferSize,
|
WriteBufferSize: configuration.Server.WriteBufferSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := logging.Logger()
|
logger := logging.Logger()
|
||||||
|
|
||||||
address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port))
|
address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port))
|
||||||
|
@ -204,9 +295,8 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
||||||
|
|
||||||
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" {
|
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" {
|
||||||
connectionType, connectionScheme = "TLS", schemeHTTPS
|
connectionType, connectionScheme = "TLS", schemeHTTPS
|
||||||
err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)
|
|
||||||
|
|
||||||
if err != nil {
|
if err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key); err != nil {
|
||||||
logger.Fatalf("unable to load certificate: %v", err)
|
logger.Fatalf("unable to load certificate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,14 +318,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
||||||
server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone())
|
if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
|
||||||
if err != nil {
|
|
||||||
logger.Fatalf("Error initializing listener: %s", err)
|
logger.Fatalf("Error initializing listener: %s", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
connectionType, connectionScheme = "non-TLS", schemeHTTP
|
connectionType, connectionScheme = "non-TLS", schemeHTTP
|
||||||
listener, err = net.Listen("tcp", address)
|
|
||||||
if err != nil {
|
if listener, err = net.Listen("tcp", address); err != nil {
|
||||||
logger.Fatalf("Error initializing listener: %s", err)
|
logger.Fatalf("Error initializing listener: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,11 +334,10 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
||||||
logger.Fatalf("Could not configure healthcheck: %v", err)
|
logger.Fatalf("Could not configure healthcheck: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
actualAddress := listener.Addr().String()
|
|
||||||
if configuration.Server.Path == "" {
|
if configuration.Server.Path == "" {
|
||||||
logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress)
|
logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, listener.Addr().String())
|
||||||
} else {
|
} else {
|
||||||
logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path)
|
logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, listener.Addr().String(), configuration.Server.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return server, listener
|
return server, listener
|
||||||
|
|
|
@ -48,14 +48,14 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var scheme = "https"
|
var scheme = schemeHTTPS
|
||||||
|
|
||||||
if !https {
|
if !https {
|
||||||
proto := string(ctx.XForwardedProto())
|
proto := string(ctx.XForwardedProto())
|
||||||
switch proto {
|
switch proto {
|
||||||
case "":
|
case "":
|
||||||
break
|
break
|
||||||
case "http", "https":
|
case schemeHTTP, schemeHTTPS:
|
||||||
scheme = proto
|
scheme = proto
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -116,7 +116,7 @@ func writeHealthCheckEnv(disabled bool, scheme, host, path string, port int) (er
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if host == "0.0.0.0" {
|
if host == "0.0.0.0" {
|
||||||
host = "localhost"
|
host = localhost
|
||||||
} else if strings.Contains(host, ":") {
|
} else if strings.Contains(host, ":") {
|
||||||
host = "[" + host + "]"
|
host = "[" + host + "]"
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
|
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
|
||||||
|
@ -145,6 +147,52 @@ func IsStringSlicesDifferentFold(a, b []string) (different bool) {
|
||||||
return isStringSlicesDifferent(a, b, IsStringInSliceFold)
|
return isStringSlicesDifferent(a, b, IsStringInSliceFold)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsURLInSlice returns true if the needle url.URL is in the []url.URL haystack.
|
||||||
|
func IsURLInSlice(needle url.URL, haystack []url.URL) (has bool) {
|
||||||
|
for i := 0; i < len(haystack); i++ {
|
||||||
|
if strings.EqualFold(needle.String(), haystack[i].String()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringSliceFromURLs returns a []string from a []url.URL.
|
||||||
|
func StringSliceFromURLs(urls []url.URL) []string {
|
||||||
|
result := make([]string, len(urls))
|
||||||
|
|
||||||
|
for i := 0; i < len(urls); i++ {
|
||||||
|
result[i] = urls[i].String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLsFromStringSlice returns a []url.URL from a []string.
|
||||||
|
func URLsFromStringSlice(urls []string) []url.URL {
|
||||||
|
var result []url.URL
|
||||||
|
|
||||||
|
for i := 0; i < len(urls); i++ {
|
||||||
|
u, err := url.Parse(urls[i])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, *u)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// OriginFromURL returns an origin url.URL given another url.URL.
|
||||||
|
func OriginFromURL(u url.URL) (origin url.URL) {
|
||||||
|
return url.URL{
|
||||||
|
Scheme: u.Scheme,
|
||||||
|
Host: u.Host,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string.
|
// StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string.
|
||||||
func StringSlicesDelta(before, after []string) (added, removed []string) {
|
func StringSlicesDelta(before, after []string) (added, removed []string) {
|
||||||
for _, s := range before {
|
for _, s := range before {
|
||||||
|
@ -193,6 +241,19 @@ func StringHTMLEscape(input string) (output string) {
|
||||||
return htmlEscaper.Replace(input)
|
return htmlEscaper.Replace(input)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinAndCanonicalizeHeaders join header strings by a given sep.
|
||||||
|
func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) {
|
||||||
|
for i, header := range headers {
|
||||||
|
if i != 0 {
|
||||||
|
joined = append(joined, sep...)
|
||||||
|
}
|
||||||
|
|
||||||
|
joined = fasthttp.AppendNormalizedHeaderKey(joined, header)
|
||||||
|
}
|
||||||
|
|
||||||
|
return joined
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -171,3 +172,48 @@ func TestIsStringSliceContainsAny(t *testing.T) {
|
||||||
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
|
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
|
||||||
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
|
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStringSliceURLConversionFuncs(t *testing.T) {
|
||||||
|
urls := URLsFromStringSlice([]string{"https://google.com", "abc", "%*()@#$J(@*#$J@#($H"})
|
||||||
|
|
||||||
|
require.Len(t, urls, 2)
|
||||||
|
assert.Equal(t, "https://google.com", urls[0].String())
|
||||||
|
assert.Equal(t, "abc", urls[1].String())
|
||||||
|
|
||||||
|
strs := StringSliceFromURLs(urls)
|
||||||
|
|
||||||
|
require.Len(t, strs, 2)
|
||||||
|
assert.Equal(t, "https://google.com", strs[0])
|
||||||
|
assert.Equal(t, "abc", strs[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsURLInSlice(t *testing.T) {
|
||||||
|
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
|
||||||
|
|
||||||
|
google, err := url.Parse("https://google.com")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
microsoft, err := url.Parse("https://microsoft.com")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
example, err := url.Parse("https://example.com")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, IsURLInSlice(*google, urls))
|
||||||
|
assert.False(t, IsURLInSlice(*microsoft, urls))
|
||||||
|
assert.True(t, IsURLInSlice(*example, urls))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginFromURL(t *testing.T) {
|
||||||
|
google, err := url.Parse("https://google.com/abc?a=123#five")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
origin := OriginFromURL(*google)
|
||||||
|
assert.Equal(t, "https://google.com", origin.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinAndCanonicalizeHeaders(t *testing.T) {
|
||||||
|
result := JoinAndCanonicalizeHeaders([]byte(", "), "x-example-ONE", "X-EGG-Two")
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("X-Example-One, X-Egg-Two"), result)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue