feat(oidc): pushed authorization requests (#4546)

This implements RFC9126 OAuth 2.0 Pushed Authorization Requests. See https://datatracker.ietf.org/doc/html/rfc9126 for the specification details.
pull/5033/head^2
James Elliott 2023-03-06 14:58:50 +11:00 committed by GitHub
parent 42671d3edb
commit ff6be40f5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 872 additions and 241 deletions

View File

@ -272,6 +272,23 @@ Allows [PKCE] `plain` challenges when set to `true`.
*__Security Notice:__* Changing this value is generally discouraged. Applications should use the `S256` [PKCE] challenge *__Security Notice:__* Changing this value is generally discouraged. Applications should use the `S256` [PKCE] challenge
method instead. method instead.
### pushed_authorizations
Controls the behaviour of [Pushed Authorization Requests].
#### enforce
{{< confkey type="boolean" default="false" required="no" >}}
When enabled all authorization requests must use the [Pushed Authorization Requests] flow.
#### context_lifespan
{{< confkey type="duration" default="5m" required="no" >}}
The maximum amount of time between the [Pushed Authorization Requests] flow being initiated and the generated
`request_uri` being utilized by a client.
### cors ### cors
Some [OpenID Connect 1.0] Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows Some [OpenID Connect 1.0] Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
@ -285,6 +302,7 @@ A list of endpoints to configure with cross-origin resource sharing headers. It
option is at least in this list. The potential endpoints which this can be enabled on are as follows: option is at least in this list. The potential endpoints which this can be enabled on are as follows:
* authorization * authorization
* pushed-authorization-request
* token * token
* revocation * revocation
* introspection * introspection
@ -472,6 +490,12 @@ See the [Response Modes](../../integration/openid-connect/introduction.md#respon
The authorization policy for this client: either `one_factor` or `two_factor`. The authorization policy for this client: either `one_factor` or `two_factor`.
#### enforce_par
{{< confkey type="boolean" default="false" required="no" >}}
Enforces the use of a [Pushed Authorization Requests] flow for this client.
#### enforce_pkce #### enforce_pkce
{{< confkey type="bool" default="false" required="no" >}} {{< confkey type="bool" default="false" required="no" >}}
@ -550,3 +574,4 @@ To integrate Authelia's [OpenID Connect 1.0] implementation with a relying party
[Authorization Code Flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth [Authorization Code Flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth
[Subject Identifier Type]: https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes [Subject Identifier Type]: https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
[Pairwise Identifier Algorithm]: https://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg [Pairwise Identifier Algorithm]: https://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg
[Pushed Authorization Requests]: https://datatracker.ietf.org/doc/html/rfc9126

View File

@ -36,3 +36,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
| 5 | 4.35.1 | Fixed the oauth2_consent_session table to accept NULL subjects for users who are not yet signed in | | 5 | 4.35.1 | Fixed the oauth2_consent_session table to accept NULL subjects for users who are not yet signed in |
| 6 | 4.37.0 | Adjusted the OpenID Connect tables to allow pre-configured consent improvements | | 6 | 4.37.0 | Adjusted the OpenID Connect tables to allow pre-configured consent improvements |
| 7 | 4.37.3 | Fixed some schema inconsistencies most notably the MySQL/MariaDB Engine and Collation | | 7 | 4.37.3 | Fixed some schema inconsistencies most notably the MySQL/MariaDB Engine and Collation |
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |

View File

@ -210,14 +210,71 @@ These endpoints can be utilized to discover other endpoints and metadata about t
These endpoints implement OpenID Connect elements. These endpoints implement OpenID Connect elements.
| Endpoint | Path | Discovery Attribute | | Endpoint | Path | Discovery Attribute |
|:-------------------:|:-----------------------------------------------:|:----------------------:| |:-------------------------------:|:--------------------------------------------------------------:|:-------------------------------------:|
| [JSON Web Key Sets] | https://auth.example.com/jwks.json | jwks_uri | | [JSON Web Key Set] | https://auth.example.com/jwks.json | jwks_uri |
| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint | | [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
| [Pushed Authorization Requests] | https://auth.example.com/api/oidc/pushed-authorization-request | pushed_authorization_request_endpoint |
| [Token] | https://auth.example.com/api/oidc/token | token_endpoint | | [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
| [UserInfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint | | [UserInfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint | | [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint | | [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
## Security
The following information covers some security topics some users may wish to be familiar with.
#### Pushed Authorization Requests Endpoint
The [Pushed Authorization Requests] endpoint is discussed in depth in [RFC9126] as well as in the
[OAuth 2.0 Pushed Authorization Requests](https://oauth.net/2/pushed-authorization-requests/) documentation.
Essentially it's a special endpoint that takes the same parameters as the [Authorization] endpoint (including
[Proof Key Code Exchange](#proof-key-code-exchange)) with a few caveats:
1. The same [Client Authentication] mechanism required by the [Token] endpoint **MUST** be used.
2. The request **MUST** use the [HTTP POST method].
3. The request **MUST** use the `application/x-www-form-urlencoded` content type (i.e. the parameters **MUST** be in the
body, not the URI).
4. The request **MUST** occur over the back-channel.
The response of this endpoint is a JSON Object with two key-value pairs:
- `request_uri`
- `expires_in`
The `expires_in` indicates how long the `request_uri` is valid for. The `request_uri` is used as a parameter to the
[Authorization] endpoint instead of the standard parameters (as the `request_uri` parameter).
The advantages of this approach are as follows:
1. [Pushed Authorization Requests] cannot be created or influenced by any party other than the Relying Party (client).
2. Since you can force all [Authorization] requests to be initiated via [Pushed Authorization Requests] you drastically
improve the authorization flows resistance to phishing attacks (this can be done globally or on a per-client basis).
3. Since the [Pushed Authorization Requests] endpoint requires all of the same [Client Authentication] mechanisms as the
[Token] endpoint:
1. Clients using the confidential [Client Type] can't have [Pushed Authorization Requests] generated by parties who do not
have the credentials.
2. Clients using the public [Client Type] and utilizing [Proof Key Code Exchange](#proof-key-code-exchange) never
transmit the verifier over any front-channel making even the `plain` challenge method relatively secure.
#### Proof Key Code Exchange
The [Proof Key Code Exchange] mechanism is discussed in depth in [RFC7636] as well as in the
[OAuth 2.0 Proof Key Code Exchange](https://oauth.net/2/pkce/) documentation.
Essentially a random opaque value is generated by the Relying Party and optionally (but recommended) passed through a
SHA256 hash. The original value is saved by the Relying Party, and the hashed value is sent in the [Authorization]
request in the `code_verifier` parameter with the `code_challenge_method` set to `S256` (or `plain` using a bad practice
of not hashing the opaque value).
When the Relying Party requests the token from the [Token] endpoint, they must include the `code_verifier` parameter
again (in the body), but this time they send the value without it being hashed.
The advantages of this approach are as follows:
1. Provided the value was hashed it's certain that the Relying Party which generated the authorization request is the
same party as the one requesting the token or is permitted by the Relying Party to make this request.
2. Even when using the public [Client Type] there is a form of authentication on the [Token] endpoint.
[ID Token]: https://openid.net/specs/openid-connect-core-1_0.html#IDToken [ID Token]: https://openid.net/specs/openid-connect-core-1_0.html#IDToken
[Access Token]: https://datatracker.ietf.org/doc/html/rfc6749#section-1.4 [Access Token]: https://datatracker.ietf.org/doc/html/rfc6749#section-1.4
[Refresh Token]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens [Refresh Token]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
@ -230,14 +287,23 @@ These endpoints implement OpenID Connect elements.
[OpenID Connect Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html [OpenID Connect Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
[OAuth 2.0 Authorization Server Metadata]: https://datatracker.ietf.org/doc/html/rfc8414 [OAuth 2.0 Authorization Server Metadata]: https://datatracker.ietf.org/doc/html/rfc8414
[JSON Web Key Sets]: https://datatracker.ietf.org/doc/html/rfc7517#section-5 [JSON Web Key Set]: https://datatracker.ietf.org/doc/html/rfc7517#section-5
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint [Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
[Pushed Authorization Requests]: https://datatracker.ietf.org/doc/html/rfc9126
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint [Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
[UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662 [Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009 [Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
[Proof Key Code Exchange]: https://www.rfc-editor.org/rfc/rfc7636.html
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
[RFC4122]: https://datatracker.ietf.org/doc/html/rfc4122
[Subject Identifier Types]: https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes [Subject Identifier Types]: https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
[Client Authentication]: https://datatracker.ietf.org/doc/html/rfc6749#section-2.3
[Client Type]: https://oauth.net/2/client-types/
[HTTP POST method]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/POST
[Proof Key Code Exchange]: #proof-key-code-exchange
[RFC4122]: https://datatracker.ietf.org/doc/html/rfc4122
[RFC7636]: https://datatracker.ietf.org/doc/html/rfc7636
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
[RFC9126]: https://datatracker.ietf.org/doc/html/rfc9126

File diff suppressed because one or more lines are too long

View File

@ -29,10 +29,17 @@ type OpenIDConnectConfiguration struct {
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"` EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
CORS OpenIDConnectCORSConfiguration `koanf:"cors"` CORS OpenIDConnectCORSConfiguration `koanf:"cors"`
PAR OpenIDConnectPARConfiguration `koanf:"pushed_authorizations"`
Clients []OpenIDConnectClientConfiguration `koanf:"clients"` Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
} }
// OpenIDConnectPARConfiguration represents an OpenID Connect PAR config.
type OpenIDConnectPARConfiguration struct {
Enforce bool `koanf:"enforce"`
ContextLifespan time.Duration `koanf:"context_lifespan"`
}
// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config. // OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config.
type OpenIDConnectCORSConfiguration struct { type OpenIDConnectCORSConfiguration struct {
Endpoints []string `koanf:"endpoints"` Endpoints []string `koanf:"endpoints"`
@ -59,6 +66,7 @@ type OpenIDConnectClientConfiguration struct {
Policy string `koanf:"authorization_policy"` Policy string `koanf:"authorization_policy"`
EnforcePAR bool `koanf:"enforce_par"`
EnforcePKCE bool `koanf:"enforce_pkce"` EnforcePKCE bool `koanf:"enforce_pkce"`
PKCEChallengeMethod string `koanf:"pkce_challenge_method"` PKCEChallengeMethod string `koanf:"pkce_challenge_method"`

View File

@ -31,6 +31,8 @@ var Keys = []string{
"identity_providers.oidc.cors.endpoints", "identity_providers.oidc.cors.endpoints",
"identity_providers.oidc.cors.allowed_origins", "identity_providers.oidc.cors.allowed_origins",
"identity_providers.oidc.cors.allowed_origins_from_client_redirect_uris", "identity_providers.oidc.cors.allowed_origins_from_client_redirect_uris",
"identity_providers.oidc.pushed_authorizations.enforce",
"identity_providers.oidc.pushed_authorizations.context_lifespan",
"identity_providers.oidc.clients", "identity_providers.oidc.clients",
"identity_providers.oidc.clients[].id", "identity_providers.oidc.clients[].id",
"identity_providers.oidc.clients[].description", "identity_providers.oidc.clients[].description",
@ -44,6 +46,7 @@ var Keys = []string{
"identity_providers.oidc.clients[].response_types", "identity_providers.oidc.clients[].response_types",
"identity_providers.oidc.clients[].response_modes", "identity_providers.oidc.clients[].response_modes",
"identity_providers.oidc.clients[].authorization_policy", "identity_providers.oidc.clients[].authorization_policy",
"identity_providers.oidc.clients[].enforce_par",
"identity_providers.oidc.clients[].enforce_pkce", "identity_providers.oidc.clients[].enforce_pkce",
"identity_providers.oidc.clients[].pkce_challenge_method", "identity_providers.oidc.clients[].pkce_challenge_method",
"identity_providers.oidc.clients[].userinfo_signing_algorithm", "identity_providers.oidc.clients[].userinfo_signing_algorithm",

View File

@ -392,7 +392,7 @@ var (
validOIDCGrantTypes = []string{oidc.GrantTypeImplicit, oidc.GrantTypeRefreshToken, oidc.GrantTypeAuthorizationCode, oidc.GrantTypePassword, oidc.GrantTypeClientCredentials} validOIDCGrantTypes = []string{oidc.GrantTypeImplicit, oidc.GrantTypeRefreshToken, oidc.GrantTypeAuthorizationCode, oidc.GrantTypePassword, oidc.GrantTypeClientCredentials}
validOIDCResponseModes = []string{oidc.ResponseModeFormPost, oidc.ResponseModeQuery, oidc.ResponseModeFragment} validOIDCResponseModes = []string{oidc.ResponseModeFormPost, oidc.ResponseModeQuery, oidc.ResponseModeFragment}
validOIDCUserinfoAlgorithms = []string{oidc.SigningAlgorithmNone, oidc.SigningAlgorithmRSAWithSHA256} validOIDCUserinfoAlgorithms = []string{oidc.SigningAlgorithmNone, oidc.SigningAlgorithmRSAWithSHA256}
validOIDCCORSEndpoints = []string{oidc.EndpointAuthorization, oidc.EndpointToken, oidc.EndpointIntrospection, oidc.EndpointRevocation, oidc.EndpointUserinfo} validOIDCCORSEndpoints = []string{oidc.EndpointAuthorization, oidc.EndpointPushedAuthorizationRequest, oidc.EndpointToken, oidc.EndpointIntrospection, oidc.EndpointRevocation, oidc.EndpointUserinfo}
validOIDCClientConsentModes = []string{"auto", oidc.ClientConsentModeImplicit.String(), oidc.ClientConsentModeExplicit.String(), oidc.ClientConsentModePreConfigured.String()} validOIDCClientConsentModes = []string{"auto", oidc.ClientConsentModeImplicit.String(), oidc.ClientConsentModeExplicit.String(), oidc.ClientConsentModePreConfigured.String()}
) )

View File

@ -80,7 +80,7 @@ func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) {
require.Len(t, validator.Errors(), 1) require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'") assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'pushed-authorization-request', 'token', 'introspection', 'revocation', 'userinfo'")
} }
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) { func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {

View File

@ -53,10 +53,20 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
return return
} }
if err = client.ValidateAuthorizationPolicy(requester); err != nil { if err = client.ValidatePARPolicy(requester, ctx.Providers.OpenIDConnect.GetPushedAuthorizeRequestURIPrefix(ctx)); err != nil {
rfc := fosite.ErrorToRFC6749Error(err) rfc := fosite.ErrorToRFC6749Error(err)
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the authorization policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription()) ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the PAR policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err)
return
}
if err = client.ValidatePKCEPolicy(requester); err != nil {
rfc := fosite.ErrorToRFC6749Error(err)
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err) ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err)
@ -95,13 +105,13 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID) ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID)
oidcSession := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(), session := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(),
userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, consent, requester) userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, consent, requester)
ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v", ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v",
requester.GetID(), oidcSession.ClientID, oidcSession.Subject, oidcSession.Username, oidcSession.Claims) requester.GetID(), session.ClientID, session.Subject, session.Username, session.Claims)
if responder, err = ctx.Providers.OpenIDConnect.NewAuthorizeResponse(ctx, requester, oidcSession); err != nil { if responder, err = ctx.Providers.OpenIDConnect.NewAuthorizeResponse(ctx, requester, session); err != nil {
rfc := fosite.ErrorToRFC6749Error(err) rfc := fosite.ErrorToRFC6749Error(err)
ctx.Logger.Errorf("Authorization Response for Request with id '%s' on client with id '%s' could not be created: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription()) ctx.Logger.Errorf("Authorization Response for Request with id '%s' on client with id '%s' could not be created: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
@ -125,3 +135,62 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
ctx.Providers.OpenIDConnect.WriteAuthorizeResponse(ctx, rw, requester, responder) ctx.Providers.OpenIDConnect.WriteAuthorizeResponse(ctx, rw, requester, responder)
} }
// OpenIDConnectPushedAuthorizationRequest handles POST requests to the OAuth 2.0 Pushed Authorization Requests endpoint.
//
// RFC9126 https://www.rfc-editor.org/rfc/rfc9126.html
func OpenIDConnectPushedAuthorizationRequest(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
var (
requester fosite.AuthorizeRequester
responder fosite.PushedAuthorizeResponder
err error
)
if requester, err = ctx.Providers.OpenIDConnect.NewPushedAuthorizeRequest(ctx, r); err != nil {
rfc := fosite.ErrorToRFC6749Error(err)
ctx.Logger.Errorf("Pushed Authorization Request failed with error: %s", rfc.WithExposeDebug(true).GetDescription())
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
return
}
var client *oidc.Client
clientID := requester.GetClient().GetID()
if client, err = ctx.Providers.OpenIDConnect.GetFullClient(clientID); err != nil {
if errors.Is(err, fosite.ErrNotFound) {
ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' could not be processed: client was not found", requester.GetID(), clientID)
} else {
ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' could not be processed: failed to find client: %+v", requester.GetID(), clientID, err)
}
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
return
}
if err = client.ValidatePKCEPolicy(requester); err != nil {
rfc := fosite.ErrorToRFC6749Error(err)
ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
return
}
if responder, err = ctx.Providers.OpenIDConnect.NewPushedAuthorizeResponse(ctx, requester, oidc.NewSession()); err != nil {
rfc := fosite.ErrorToRFC6749Error(err)
ctx.Logger.Errorf("Pushed Authorization Request failed with error: %s", rfc.WithExposeDebug(true).GetDescription())
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
return
}
ctx.Providers.OpenIDConnect.WritePushedAuthorizeResponse(ctx, rw, requester, responder)
}

View File

@ -9,9 +9,8 @@ import (
mail "net/mail" mail "net/mail"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock"
templates "github.com/authelia/authelia/v4/internal/templates" templates "github.com/authelia/authelia/v4/internal/templates"
gomock "github.com/golang/mock/gomock"
) )
// MockNotifier is a mock of Notifier interface. // MockNotifier is a mock of Notifier interface.

View File

@ -10,11 +10,10 @@ import (
reflect "reflect" reflect "reflect"
time "time" time "time"
gomock "github.com/golang/mock/gomock"
uuid "github.com/google/uuid"
model "github.com/authelia/authelia/v4/internal/model" model "github.com/authelia/authelia/v4/internal/model"
storage "github.com/authelia/authelia/v4/internal/storage" storage "github.com/authelia/authelia/v4/internal/storage"
gomock "github.com/golang/mock/gomock"
uuid "github.com/google/uuid"
) )
// MockStorage is a mock of Provider interface. // MockStorage is a mock of Provider interface.
@ -40,6 +39,7 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder return m.recorder
} }
// AppendAuthenticationLog mocks base method. // AppendAuthenticationLog mocks base method.
func (m *MockStorage) AppendAuthenticationLog(arg0 context.Context, arg1 model.AuthenticationAttempt) error { func (m *MockStorage) AppendAuthenticationLog(arg0 context.Context, arg1 model.AuthenticationAttempt) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -270,6 +270,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionByChallengeID(arg0, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionByChallengeID", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionByChallengeID), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionByChallengeID", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionByChallengeID), arg0, arg1)
} }
// LoadOAuth2PARContext mocks base method.
func (m *MockStorage) LoadOAuth2PARContext(arg0 context.Context, arg1 string) (*model.OAuth2PARContext, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOAuth2PARContext", arg0, arg1)
ret0, _ := ret[0].(*model.OAuth2PARContext)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOAuth2PARContext indicates an expected call of LoadOAuth2PARContext.
func (mr *MockStorageMockRecorder) LoadOAuth2PARContext(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2PARContext", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2PARContext), arg0, arg1)
}
// LoadOAuth2Session mocks base method. // LoadOAuth2Session mocks base method.
func (m *MockStorage) LoadOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) (*model.OAuth2Session, error) { func (m *MockStorage) LoadOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) (*model.OAuth2Session, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -435,6 +450,20 @@ func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1 inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1)
} }
// RevokeOAuth2PARContext mocks base method.
func (m *MockStorage) RevokeOAuth2PARContext(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeOAuth2PARContext", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeOAuth2PARContext indicates an expected call of RevokeOAuth2PARContext.
func (mr *MockStorageMockRecorder) RevokeOAuth2PARContext(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2PARContext", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2PARContext), arg0, arg1)
}
// RevokeOAuth2Session mocks base method. // RevokeOAuth2Session mocks base method.
func (m *MockStorage) RevokeOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { func (m *MockStorage) RevokeOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -576,6 +605,20 @@ func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionSubject(arg0, arg1 in
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionSubject", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionSubject), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionSubject", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionSubject), arg0, arg1)
} }
// SaveOAuth2PARContext mocks base method.
func (m *MockStorage) SaveOAuth2PARContext(arg0 context.Context, arg1 model.OAuth2PARContext) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOAuth2PARContext", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveOAuth2PARContext indicates an expected call of SaveOAuth2PARContext.
func (mr *MockStorageMockRecorder) SaveOAuth2PARContext(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2PARContext", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2PARContext), arg0, arg1)
}
// SaveOAuth2Session mocks base method. // SaveOAuth2Session mocks base method.
func (m *MockStorage) SaveOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 model.OAuth2Session) error { func (m *MockStorage) SaveOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 model.OAuth2Session) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -7,9 +7,8 @@ package mocks
import ( import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock"
model "github.com/authelia/authelia/v4/internal/model" model "github.com/authelia/authelia/v4/internal/model"
gomock "github.com/golang/mock/gomock"
) )
// MockTOTP is a mock of Provider interface. // MockTOTP is a mock of Provider interface.

View File

@ -7,9 +7,8 @@ package mocks
import ( import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock"
authentication "github.com/authelia/authelia/v4/internal/authentication" authentication "github.com/authelia/authelia/v4/internal/authentication"
gomock "github.com/golang/mock/gomock"
) )
// MockUserProvider is a mock of UserProvider interface. // MockUserProvider is a mock of UserProvider interface.

View File

@ -39,6 +39,14 @@ func NewOAuth2ConsentSession(subject uuid.UUID, r fosite.Requester) (consent *OA
return consent, nil return consent, nil
} }
// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI.
func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) {
return OAuth2BlacklistedJTI{
Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))),
ExpiresAt: exp,
}
}
// NewOAuth2SessionFromRequest creates a new OAuth2Session from a signature and fosite.Requester. // NewOAuth2SessionFromRequest creates a new OAuth2Session from a signature and fosite.Requester.
func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session *OAuth2Session, err error) { func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session *OAuth2Session, err error) {
var ( var (
@ -77,12 +85,43 @@ func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session
}, nil }, nil
} }
// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI. // NewOAuth2PARContext creates a new Pushed Authorization Request Context as a OAuth2PARContext.
func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) { func NewOAuth2PARContext(contextID string, r fosite.AuthorizeRequester) (context *OAuth2PARContext, err error) {
return OAuth2BlacklistedJTI{ var (
Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))), s *OpenIDSession
ExpiresAt: exp, ok bool
req *fosite.AuthorizeRequest
session []byte
)
if s, ok = r.GetSession().(*OpenIDSession); !ok {
return nil, fmt.Errorf("can't convert type '%T' to an *OAuth2Session", r.GetSession())
} }
if session, err = json.Marshal(s); err != nil {
return nil, err
}
var handled StringSlicePipeDelimited
if req, ok = r.(*fosite.AuthorizeRequest); ok {
handled = StringSlicePipeDelimited(req.HandledResponseTypes)
}
return &OAuth2PARContext{
Signature: contextID,
RequestID: r.GetID(),
ClientID: r.GetClient().GetID(),
RequestedAt: r.GetRequestedAt(),
Scopes: StringSlicePipeDelimited(r.GetRequestedScopes()),
Audience: StringSlicePipeDelimited(r.GetRequestedAudience()),
HandledResponseTypes: handled,
ResponseMode: string(r.GetResponseMode()),
DefaultResponseMode: string(r.GetDefaultResponseMode()),
Revoked: false,
Form: r.GetRequestForm().Encode(),
Session: session,
}, nil
} }
// OAuth2ConsentPreConfig stores information about an OAuth2.0 Pre-Configured Consent. // OAuth2ConsentPreConfig stores information about an OAuth2.0 Pre-Configured Consent.
@ -264,6 +303,70 @@ func (s *OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, s
}, nil }, nil
} }
// OAuth2PARContext holds relevant information about a Pushed Authorization Request in order to process the authorization.
type OAuth2PARContext struct {
ID int `db:"id"`
Signature string `db:"signature"`
RequestID string `db:"request_id"`
ClientID string `db:"client_id"`
RequestedAt time.Time `db:"requested_at"`
Scopes StringSlicePipeDelimited `db:"scopes"`
Audience StringSlicePipeDelimited `db:"audience"`
HandledResponseTypes StringSlicePipeDelimited `db:"handled_response_types"`
ResponseMode string `db:"response_mode"`
DefaultResponseMode string `db:"response_mode_default"`
Revoked bool `db:"revoked"`
Form string `db:"form_data"`
Session []byte `db:"session_data"`
}
func (par *OAuth2PARContext) ToAuthorizeRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.AuthorizeRequest, err error) {
if session != nil {
if err = json.Unmarshal(par.Session, session); err != nil {
return nil, err
}
}
var (
client fosite.Client
form url.Values
)
if client, err = store.GetClient(ctx, par.ClientID); err != nil {
return nil, err
}
if form, err = url.ParseQuery(par.Form); err != nil {
return nil, err
}
request = fosite.NewAuthorizeRequest()
request.Request = fosite.Request{
ID: par.RequestID,
RequestedAt: par.RequestedAt,
Client: client,
RequestedScope: fosite.Arguments(par.Scopes),
RequestedAudience: fosite.Arguments(par.Audience),
Form: form,
Session: session,
}
if par.ResponseMode != "" {
request.ResponseMode = fosite.ResponseModeType(par.ResponseMode)
}
if par.DefaultResponseMode != "" {
request.DefaultResponseMode = fosite.ResponseModeType(par.DefaultResponseMode)
}
if len(par.HandledResponseTypes) != 0 {
request.HandledResponseTypes = fosite.Arguments(par.HandledResponseTypes)
}
return request, nil
}
// OpenIDSession holds OIDC Session information. // OpenIDSession holds OIDC Session information.
type OpenIDSession struct { type OpenIDSession struct {
*openid.DefaultSession `json:"id_token"` *openid.DefaultSession `json:"id_token"`

View File

@ -1,7 +1,7 @@
package oidc package oidc
import ( import (
"fmt" "strings"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/x/errorsx" "github.com/ory/x/errorsx"
@ -32,6 +32,8 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client)
ResponseTypes: config.ResponseTypes, ResponseTypes: config.ResponseTypes,
ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeDefault}, ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeDefault},
EnforcePAR: config.EnforcePAR,
UserinfoSigningAlgorithm: config.UserinfoSigningAlgorithm, UserinfoSigningAlgorithm: config.UserinfoSigningAlgorithm,
Policy: authorization.NewLevel(config.Policy), Policy: authorization.NewLevel(config.Policy),
@ -46,22 +48,22 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client)
return client return client
} }
// ValidateAuthorizationPolicy is a helper function to validate additional policy constraints on a per-client basis. // ValidatePKCEPolicy is a helper function to validate PKCE policy constraints on a per-client basis.
func (c *Client) ValidateAuthorizationPolicy(r fosite.Requester) (err error) { func (c *Client) ValidatePKCEPolicy(r fosite.Requester) (err error) {
form := r.GetRequestForm() form := r.GetRequestForm()
if c.EnforcePKCE { if c.EnforcePKCE {
if form.Get("code_challenge") == "" { if form.Get(FormParameterCodeChallenge) == "" {
return errorsx.WithStack(fosite.ErrInvalidRequest. return errorsx.WithStack(fosite.ErrInvalidRequest.
WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing."). WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for this client.")) WithDebug("The server is configured in a way that enforces PKCE for this client."))
} }
if c.EnforcePKCEChallengeMethod { if c.EnforcePKCEChallengeMethod {
if method := form.Get("code_challenge_method"); method != c.PKCEChallengeMethod { if method := form.Get(FormParameterCodeChallengeMethod); method != c.PKCEChallengeMethod {
return errorsx.WithStack(fosite.ErrInvalidRequest. return errorsx.WithStack(fosite.ErrInvalidRequest.
WithHint(fmt.Sprintf("Client must use code_challenge_method=%s, %s is not allowed.", c.PKCEChallengeMethod, method)). WithHintf("Client must use code_challenge_method=%s, %s is not allowed.", c.PKCEChallengeMethod, method).
WithDebug(fmt.Sprintf("The server is configured in a way that enforces PKCE %s as challenge method for this client.", c.PKCEChallengeMethod))) WithDebugf("The server is configured in a way that enforces PKCE %s as challenge method for this client.", c.PKCEChallengeMethod))
} }
} }
} }
@ -69,6 +71,23 @@ func (c *Client) ValidateAuthorizationPolicy(r fosite.Requester) (err error) {
return nil return nil
} }
// ValidatePARPolicy is a helper function to validate additional policy constraints on a per-client basis.
func (c *Client) ValidatePARPolicy(r fosite.Requester, prefix string) (err error) {
form := r.GetRequestForm()
if c.EnforcePAR {
if requestURI := form.Get(FormParameterRequestURI); !strings.HasPrefix(requestURI, prefix) {
if requestURI == "" {
return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebug("The request_uri parameter was empty."))
}
return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebugf("The request_uri parameter '%s' is malformed.", requestURI))
}
}
return nil
}
// IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient. // IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient.
func (c *Client) IsAuthenticationLevelSufficient(level authentication.Level) bool { func (c *Client) IsAuthenticationLevelSufficient(level authentication.Level) bool {
if level == authentication.NotAuthenticated { if level == authentication.NotAuthenticated {
@ -105,7 +124,7 @@ func (c *Client) GetID() string {
} }
// GetHashedSecret returns the Secret. // GetHashedSecret returns the Secret.
func (c *Client) GetHashedSecret() []byte { func (c *Client) GetHashedSecret() (secret []byte) {
if c.Secret == nil { if c.Secret == nil {
return []byte(nil) return []byte(nil)
} }
@ -114,7 +133,7 @@ func (c *Client) GetHashedSecret() []byte {
} }
// GetRedirectURIs returns the RedirectURIs. // GetRedirectURIs returns the RedirectURIs.
func (c *Client) GetRedirectURIs() []string { func (c *Client) GetRedirectURIs() (redirectURIs []string) {
return c.RedirectURIs return c.RedirectURIs
} }

View File

@ -224,7 +224,7 @@ func TestNewClientPKCE(t *testing.T) {
expectedEnforcePKCE bool expectedEnforcePKCE bool
expectedEnforcePKCEChallengeMethod bool expectedEnforcePKCEChallengeMethod bool
expected string expected string
req *fosite.Request r *fosite.Request
err string err string
}{ }{
{ {
@ -288,8 +288,8 @@ func TestNewClientPKCE(t *testing.T) {
assert.Equal(t, tc.expectedEnforcePKCEChallengeMethod, client.EnforcePKCEChallengeMethod) assert.Equal(t, tc.expectedEnforcePKCEChallengeMethod, client.EnforcePKCEChallengeMethod)
assert.Equal(t, tc.expected, client.PKCEChallengeMethod) assert.Equal(t, tc.expected, client.PKCEChallengeMethod)
if tc.req != nil { if tc.r != nil {
err := client.ValidateAuthorizationPolicy(tc.req) err := client.ValidatePKCEPolicy(tc.r)
if tc.err != "" { if tc.err != "" {
assert.EqualError(t, err, tc.err) assert.EqualError(t, err, tc.err)

View File

@ -24,8 +24,8 @@ import (
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
func NewConfig(config *schema.OpenIDConnectConfiguration, templates *templates.Provider) *Config { func NewConfig(config *schema.OpenIDConnectConfiguration, templates *templates.Provider) (c *Config) {
c := &Config{ c = &Config{
GlobalSecret: []byte(utils.HashSHA256FromString(config.HMACSecret)), GlobalSecret: []byte(utils.HashSHA256FromString(config.HMACSecret)),
SendDebugMessagesToClients: config.EnableClientDebugMessages, SendDebugMessagesToClients: config.EnableClientDebugMessages,
MinParameterEntropy: config.MinimumParameterEntropy, MinParameterEntropy: config.MinimumParameterEntropy,
@ -40,18 +40,23 @@ func NewConfig(config *schema.OpenIDConnectConfiguration, templates *templates.P
EnforcePublicClients: config.EnforcePKCE != "never", EnforcePublicClients: config.EnforcePKCE != "never",
AllowPlainChallengeMethod: config.EnablePKCEPlainChallenge, AllowPlainChallengeMethod: config.EnablePKCEPlainChallenge,
}, },
PAR: PARConfig{
Enforced: config.PAR.Enforce,
ContextLifespan: config.PAR.ContextLifespan,
URIPrefix: urnPARPrefix,
},
Templates: templates, Templates: templates,
} }
c.Strategy.Core = &HMACCoreStrategy{ c.Strategy.Core = &HMACCoreStrategy{
Enigma: &hmac.HMACStrategy{Config: c}, Enigma: &hmac.HMACStrategy{Config: c},
Config: c, Config: c,
prefix: tokenPrefixFmt,
} }
return c return c
} }
// Config is an implementation of the fosite.Configurator.
type Config struct { type Config struct {
// GlobalSecret is the global secret used to sign and verify signatures. // GlobalSecret is the global secret used to sign and verify signatures.
GlobalSecret []byte GlobalSecret []byte
@ -68,7 +73,7 @@ type Config struct {
JWTScopeField jwt.JWTScopeFieldEnum JWTScopeField jwt.JWTScopeFieldEnum
JWTMaxDuration time.Duration JWTMaxDuration time.Duration
Hasher *AdaptiveHasher Hasher *Hasher
Hash HashConfig Hash HashConfig
Strategy StrategyConfig Strategy StrategyConfig
PAR PARConfig PAR PARConfig
@ -92,11 +97,13 @@ type Config struct {
Templates *templates.Provider Templates *templates.Provider
} }
// HashConfig holds specific fosite.Configurator information for hashing.
type HashConfig struct { type HashConfig struct {
ClientSecrets fosite.Hasher ClientSecrets fosite.Hasher
HMAC func() (h hash.Hash) HMAC func() (h hash.Hash)
} }
// StrategyConfig holds specific fosite.Configurator information for various strategies.
type StrategyConfig struct { type StrategyConfig struct {
Core oauth2.CoreStrategy Core oauth2.CoreStrategy
OpenID openid.OpenIDConnectTokenStrategy OpenID openid.OpenIDConnectTokenStrategy
@ -106,17 +113,20 @@ type StrategyConfig struct {
ClientAuthentication fosite.ClientAuthenticationStrategy ClientAuthentication fosite.ClientAuthenticationStrategy
} }
// PARConfig holds specific fosite.Configurator information for Pushed Authorization Requests.
type PARConfig struct { type PARConfig struct {
Enforced bool Enforced bool
URIPrefix string URIPrefix string
ContextLifespan time.Duration ContextLifespan time.Duration
} }
// IssuersConfig holds specific fosite.Configurator information for the issuer.
type IssuersConfig struct { type IssuersConfig struct {
IDToken string IDToken string
AccessToken string AccessToken string
} }
// HandlersConfig holds specific fosite.Configurator handlers configuration information.
type HandlersConfig struct { type HandlersConfig struct {
// ResponseMode provides an extension handler for custom response modes. // ResponseMode provides an extension handler for custom response modes.
ResponseMode fosite.ResponseModeHandler ResponseMode fosite.ResponseModeHandler
@ -137,18 +147,21 @@ type HandlersConfig struct {
PushedAuthorizeEndpoint fosite.PushedAuthorizeEndpointHandlers PushedAuthorizeEndpoint fosite.PushedAuthorizeEndpointHandlers
} }
// GrantTypeJWTBearerConfig holds specific fosite.Configurator information for the JWT Bearer Grant Type.
type GrantTypeJWTBearerConfig struct { type GrantTypeJWTBearerConfig struct {
OptionalClientAuth bool OptionalClientAuth bool
OptionalJTIClaim bool OptionalJTIClaim bool
OptionalIssuedDate bool OptionalIssuedDate bool
} }
// ProofKeyCodeExchangeConfig holds specific fosite.Configurator information for PKCE.
type ProofKeyCodeExchangeConfig struct { type ProofKeyCodeExchangeConfig struct {
Enforce bool Enforce bool
EnforcePublicClients bool EnforcePublicClients bool
AllowPlainChallengeMethod bool AllowPlainChallengeMethod bool
} }
// LifespanConfig holds specific fosite.Configurator information for various lifespans.
type LifespanConfig struct { type LifespanConfig struct {
AccessToken time.Duration AccessToken time.Duration
AuthorizeCode time.Duration AuthorizeCode time.Duration
@ -162,6 +175,7 @@ const (
PromptConsent = "consent" PromptConsent = "consent"
) )
// LoadHandlers reloads the handlers based on the current configuration.
func (c *Config) LoadHandlers(store *Store, strategy jwt.Signer) { func (c *Config) LoadHandlers(store *Store, strategy jwt.Signer) {
validator := openid.NewOpenIDConnectRequestValidator(strategy, c) validator := openid.NewOpenIDConnectRequestValidator(strategy, c)
@ -278,6 +292,10 @@ func (c *Config) LoadHandlers(store *Store, strategy jwt.Signer) {
if h, ok := handler.(fosite.RevocationHandler); ok { if h, ok := handler.(fosite.RevocationHandler); ok {
x.Revocation.Append(h) x.Revocation.Append(h)
} }
if h, ok := handler.(fosite.PushedAuthorizeEndpointHandler); ok {
x.PushedAuthorizeEndpoint.Append(h)
}
} }
c.Handlers = x c.Handlers = x
@ -533,7 +551,7 @@ func (c *Config) GetTokenURL(ctx context.Context) (tokenURL string) {
// GetSecretsHasher returns the client secrets hashing function. // GetSecretsHasher returns the client secrets hashing function.
func (c *Config) GetSecretsHasher(ctx context.Context) (hasher fosite.Hasher) { func (c *Config) GetSecretsHasher(ctx context.Context) (hasher fosite.Hasher) {
if c.Hash.ClientSecrets == nil { if c.Hash.ClientSecrets == nil {
c.Hash.ClientSecrets, _ = NewAdaptiveHasher() c.Hash.ClientSecrets, _ = NewHasher()
} }
return c.Hash.ClientSecrets return c.Hash.ClientSecrets
@ -595,7 +613,7 @@ func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool {
// GetPushedAuthorizeContextLifespan is the lifespan of the short-lived PAR context. // GetPushedAuthorizeContextLifespan is the lifespan of the short-lived PAR context.
func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) (lifespan time.Duration) { func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) (lifespan time.Duration) {
if c.PAR.ContextLifespan == 0 { if c.PAR.ContextLifespan.Seconds() == 0 {
c.PAR.ContextLifespan = lifespanPARContextDefault c.PAR.ContextLifespan = lifespanPARContextDefault
} }

View File

@ -110,6 +110,12 @@ const (
PKCEChallengeMethodSHA256 = "S256" PKCEChallengeMethodSHA256 = "S256"
) )
const (
FormParameterRequestURI = "request_uri"
FormParameterCodeChallenge = "code_challenge"
FormParameterCodeChallengeMethod = "code_challenge_method"
)
// Endpoints. // Endpoints.
const ( const (
EndpointAuthorization = "authorization" EndpointAuthorization = "authorization"
@ -117,6 +123,7 @@ const (
EndpointUserinfo = "userinfo" EndpointUserinfo = "userinfo"
EndpointIntrospection = "introspection" EndpointIntrospection = "introspection"
EndpointRevocation = "revocation" EndpointRevocation = "revocation"
EndpointPushedAuthorizationRequest = "pushed-authorization-request"
) )
// JWT Headers. // JWT Headers.
@ -126,7 +133,9 @@ const (
) )
const ( const (
tokenPrefixFmt = "authelia_%s_" //nolint:gosec tokenPrefixOrgAutheliaFmt = "authelia_%s_" //nolint:gosec
tokenPrefixOrgOryFmt = "ory_%s_" //nolint:gosec
tokenPrefixPartAccessToken = "at" tokenPrefixPartAccessToken = "at"
tokenPrefixPartRefreshToken = "rt" tokenPrefixPartRefreshToken = "rt"
tokenPrefixPartAuthorizeCode = "ac" tokenPrefixPartAuthorizeCode = "ac"
@ -146,6 +155,8 @@ const (
EndpointPathUserinfo = EndpointPathRoot + "/" + EndpointUserinfo EndpointPathUserinfo = EndpointPathRoot + "/" + EndpointUserinfo
EndpointPathIntrospection = EndpointPathRoot + "/" + EndpointIntrospection EndpointPathIntrospection = EndpointPathRoot + "/" + EndpointIntrospection
EndpointPathRevocation = EndpointPathRoot + "/" + EndpointRevocation EndpointPathRevocation = EndpointPathRoot + "/" + EndpointRevocation
EndpointPathPushedAuthorizationRequest = EndpointPathRoot + "/" + EndpointPushedAuthorizationRequest
) )
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176 // Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176

View File

@ -19,7 +19,6 @@ type HMACCoreStrategy struct {
fosite.RefreshTokenLifespanProvider fosite.RefreshTokenLifespanProvider
fosite.AuthorizeCodeLifespanProvider fosite.AuthorizeCodeLifespanProvider
} }
prefix string
} }
// AccessTokenSignature implements oauth2.AccessTokenStrategy. // AccessTokenSignature implements oauth2.AccessTokenStrategy.
@ -112,11 +111,11 @@ func (h *HMACCoreStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.R
} }
func (h *HMACCoreStrategy) getPrefix(part string) string { func (h *HMACCoreStrategy) getPrefix(part string) string {
if len(h.prefix) == 0 { return h.getCustomPrefix(tokenPrefixOrgAutheliaFmt, part)
return ""
} }
return fmt.Sprintf(h.prefix, part) func (h *HMACCoreStrategy) getCustomPrefix(tokenPrefixFmt, part string) string {
return fmt.Sprintf(tokenPrefixFmt, part)
} }
func (h *HMACCoreStrategy) setPrefix(token, part string) string { func (h *HMACCoreStrategy) setPrefix(token, part string) string {
@ -124,5 +123,9 @@ func (h *HMACCoreStrategy) setPrefix(token, part string) string {
} }
func (h *HMACCoreStrategy) trimPrefix(token, part string) string { func (h *HMACCoreStrategy) trimPrefix(token, part string) string {
if strings.HasPrefix(token, h.getCustomPrefix(tokenPrefixOrgOryFmt, part)) {
return strings.TrimPrefix(token, h.getCustomPrefix(tokenPrefixOrgOryFmt, part))
}
return strings.TrimPrefix(token, h.getPrefix(part)) return strings.TrimPrefix(token, h.getPrefix(part))
} }

View File

@ -0,0 +1,56 @@
package oidc
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHMACCoreStrategy_TrimPrefix(t *testing.T) {
testCases := []struct {
name string
have string
part string
expected string
}{
{"ShouldTrimAutheliaPrefix", "authelia_at_example", tokenPrefixPartAccessToken, "example"},
{"ShouldTrimOryPrefix", "ory_at_example", tokenPrefixPartAccessToken, "example"},
{"ShouldTrimOnlyAutheliaPrefix", "authelia_at_ory_at_example", tokenPrefixPartAccessToken, "ory_at_example"},
{"ShouldTrimOnlyOryPrefix", "ory_at_authelia_at_example", tokenPrefixPartAccessToken, "authelia_at_example"},
{"ShouldNotTrimGitHubPrefix", "gh_at_example", tokenPrefixPartAccessToken, "gh_at_example"},
}
strategy := &HMACCoreStrategy{}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, strategy.trimPrefix(tc.have, tc.part))
})
}
}
func TestHMACCoreStrategy_GetSetPrefix(t *testing.T) {
testCases := []struct {
name string
have string
expectedSet string
expectedGet string
}{
{"ShouldAddPrefix", "example", "authelia_%s_example", "authelia_%s_"},
}
strategy := &HMACCoreStrategy{}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, part := range []string{tokenPrefixPartAccessToken, tokenPrefixPartAuthorizeCode, tokenPrefixPartRefreshToken} {
t.Run(strings.ToUpper(part), func(t *testing.T) {
assert.Equal(t, fmt.Sprintf(tc.expectedSet, part), strategy.setPrefix(tc.have, part))
assert.Equal(t, fmt.Sprintf(tc.expectedGet, part), strategy.getPrefix(part))
})
}
})
}
}

View File

@ -1,7 +1,11 @@
package oidc package oidc
import (
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// NewOpenIDConnectWellKnownConfiguration generates a new OpenIDConnectWellKnownConfiguration. // NewOpenIDConnectWellKnownConfiguration generates a new OpenIDConnectWellKnownConfiguration.
func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge bool, clients map[string]*Client) (config OpenIDConnectWellKnownConfiguration) { func NewOpenIDConnectWellKnownConfiguration(c *schema.OpenIDConnectConfiguration, clients map[string]*Client) (config OpenIDConnectWellKnownConfiguration) {
config = OpenIDConnectWellKnownConfiguration{ config = OpenIDConnectWellKnownConfiguration{
CommonDiscoveryOptions: CommonDiscoveryOptions{ CommonDiscoveryOptions: CommonDiscoveryOptions{
SubjectTypesSupported: []string{ SubjectTypesSupported: []string{
@ -78,6 +82,9 @@ func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge bool, clien
SigningAlgorithmRSAWithSHA256, SigningAlgorithmRSAWithSHA256,
}, },
}, },
PushedAuthorizationDiscoveryOptions: PushedAuthorizationDiscoveryOptions{
RequirePushedAuthorizationRequests: c.PAR.Enforce,
},
} }
var pairwise, public bool var pairwise, public bool
@ -96,7 +103,7 @@ func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge bool, clien
config.SubjectTypesSupported = append(config.SubjectTypesSupported, SubjectTypePairwise) config.SubjectTypesSupported = append(config.SubjectTypesSupported, SubjectTypePairwise)
} }
if enablePKCEPlainChallenge { if c.EnablePKCEPlainChallenge {
config.CodeChallengeMethodsSupported = append(config.CodeChallengeMethodsSupported, PKCEChallengeMethodPlain) config.CodeChallengeMethodsSupported = append(config.CodeChallengeMethodsSupported, PKCEChallengeMethodPlain)
} }

View File

@ -4,12 +4,15 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/authelia/authelia/v4/internal/configuration/schema"
) )
func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) { func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
pkcePlainChallenge bool pkcePlainChallenge bool
enforcePAR bool
clients map[string]*Client clients map[string]*Client
expectCodeChallengeMethodsSupported, expectSubjectTypesSupported []string expectCodeChallengeMethodsSupported, expectSubjectTypesSupported []string
@ -63,7 +66,14 @@ func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
actual := NewOpenIDConnectWellKnownConfiguration(tc.pkcePlainChallenge, tc.clients) c := schema.OpenIDConnectConfiguration{
EnablePKCEPlainChallenge: tc.pkcePlainChallenge,
PAR: schema.OpenIDConnectPARConfiguration{
Enforce: tc.enforcePAR,
},
}
actual := NewOpenIDConnectWellKnownConfiguration(&c, tc.clients)
for _, codeChallengeMethod := range tc.expectCodeChallengeMethodsSupported { for _, codeChallengeMethod := range tc.expectCodeChallengeMethodsSupported {
assert.Contains(t, actual.CodeChallengeMethodsSupported, codeChallengeMethod) assert.Contains(t, actual.CodeChallengeMethodsSupported, codeChallengeMethod)
} }

View File

@ -9,11 +9,27 @@ import (
var errPasswordsDoNotMatch = errors.New("the passwords don't match") var errPasswordsDoNotMatch = errors.New("the passwords don't match")
var ( var (
// ErrIssuerCouldNotDerive is sent when the issuer couldn't be determined from the headers.
ErrIssuerCouldNotDerive = fosite.ErrServerError.WithHint("Could not safely derive the issuer.") ErrIssuerCouldNotDerive = fosite.ErrServerError.WithHint("Could not safely derive the issuer.")
// ErrSubjectCouldNotLookup is sent when the Subject Identifier for a user couldn't be generated or obtained from the database.
ErrSubjectCouldNotLookup = fosite.ErrServerError.WithHint("Could not lookup user subject.") ErrSubjectCouldNotLookup = fosite.ErrServerError.WithHint("Could not lookup user subject.")
// ErrConsentCouldNotPerform is sent when the Consent Session couldn't be performed for varying reasons.
ErrConsentCouldNotPerform = fosite.ErrServerError.WithHint("Could not perform consent.") ErrConsentCouldNotPerform = fosite.ErrServerError.WithHint("Could not perform consent.")
// ErrConsentCouldNotGenerate is sent when the Consent Session failed to be generated for some reason, usually a failed UUIDv4 generation.
ErrConsentCouldNotGenerate = fosite.ErrServerError.WithHint("Could not generate the consent session.") ErrConsentCouldNotGenerate = fosite.ErrServerError.WithHint("Could not generate the consent session.")
// ErrConsentCouldNotSave is sent when the Consent Session couldn't be saved to the database.
ErrConsentCouldNotSave = fosite.ErrServerError.WithHint("Could not save the consent session.") ErrConsentCouldNotSave = fosite.ErrServerError.WithHint("Could not save the consent session.")
// ErrConsentCouldNotLookup is sent when the Consent ID is not a known UUID.
ErrConsentCouldNotLookup = fosite.ErrServerError.WithHint("Failed to lookup the consent session.") ErrConsentCouldNotLookup = fosite.ErrServerError.WithHint("Failed to lookup the consent session.")
// ErrConsentMalformedChallengeID is sent when the Consent ID is not a UUID.
ErrConsentMalformedChallengeID = fosite.ErrServerError.WithHint("Malformed consent session challenge ID.") ErrConsentMalformedChallengeID = fosite.ErrServerError.WithHint("Malformed consent session challenge ID.")
// ErrPAREnforcedClientMissingPAR is sent when a client has EnforcePAR configured but the Authorization Request was not Pushed.
ErrPAREnforcedClientMissingPAR = fosite.ErrInvalidRequest.WithHint("Pushed Authorization Requests are enforced for this client but no such request was sent.")
) )

View File

@ -8,8 +8,9 @@ import (
"github.com/go-crypt/crypt/algorithm/plaintext" "github.com/go-crypt/crypt/algorithm/plaintext"
) )
func NewAdaptiveHasher() (hasher *AdaptiveHasher, err error) { // NewHasher returns a new Hasher.
hasher = &AdaptiveHasher{} func NewHasher() (hasher *Hasher, err error) {
hasher = &Hasher{}
if hasher.decoder, err = crypt.NewDefaultDecoder(); err != nil { if hasher.decoder, err = crypt.NewDefaultDecoder(); err != nil {
return nil, err return nil, err
@ -22,13 +23,13 @@ func NewAdaptiveHasher() (hasher *AdaptiveHasher, err error) {
return hasher, nil return hasher, nil
} }
// AdaptiveHasher implements the fosite.Hasher interface without an actual hashing algo. // Hasher implements the fosite.Hasher interface and adaptively compares hashes.
type AdaptiveHasher struct { type Hasher struct {
decoder algorithm.DecoderRegister decoder algorithm.DecoderRegister
} }
// Compare compares the hash with the data and returns an error if they don't match. // Compare compares the hash with the data and returns an error if they don't match.
func (h *AdaptiveHasher) Compare(_ context.Context, hash, data []byte) (err error) { func (h Hasher) Compare(_ context.Context, hash, data []byte) (err error) {
var digest algorithm.Digest var digest algorithm.Digest
if digest, err = h.decoder.Decode(string(hash)); err != nil { if digest, err = h.decoder.Decode(string(hash)); err != nil {
@ -43,6 +44,6 @@ func (h *AdaptiveHasher) Compare(_ context.Context, hash, data []byte) (err erro
} }
// Hash creates a new hash from data. // Hash creates a new hash from data.
func (h *AdaptiveHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) { func (h Hasher) Hash(_ context.Context, data []byte) (hash []byte, err error) {
return data, nil return data, nil
} }

View File

@ -9,7 +9,7 @@ import (
) )
func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) { func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
hasher, err := NewAdaptiveHasher() hasher, err := NewHasher()
require.NoError(t, err) require.NoError(t, err)
@ -22,7 +22,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
} }
func TestShouldNotRaiseErrorOnEqualPasswordsPlainTextWithSeparator(t *testing.T) { func TestShouldNotRaiseErrorOnEqualPasswordsPlainTextWithSeparator(t *testing.T) {
hasher, err := NewAdaptiveHasher() hasher, err := NewHasher()
require.NoError(t, err) require.NoError(t, err)
@ -35,7 +35,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainTextWithSeparator(t *testing.T)
} }
func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) { func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
hasher, err := NewAdaptiveHasher() hasher, err := NewHasher()
require.NoError(t, err) require.NoError(t, err)
@ -48,7 +48,7 @@ func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
} }
func TestShouldHashPassword(t *testing.T) { func TestShouldHashPassword(t *testing.T) {
hasher := AdaptiveHasher{} hasher := Hasher{}
data := []byte("abc") data := []byte("abc")

View File

@ -37,7 +37,7 @@ func NewOpenIDConnectProvider(config *schema.OpenIDConnectConfiguration, store s
provider.Config.LoadHandlers(provider.Store, provider.KeyManager.Strategy()) provider.Config.LoadHandlers(provider.Store, provider.KeyManager.Strategy())
provider.discovery = NewOpenIDConnectWellKnownConfiguration(config.EnablePKCEPlainChallenge, provider.Store.clients) provider.discovery = NewOpenIDConnectWellKnownConfiguration(config, provider.Store.clients)
return provider, nil return provider, nil
} }
@ -50,12 +50,12 @@ func (p *OpenIDConnectProvider) GetOAuth2WellKnownConfiguration(issuer string) O
} }
options.Issuer = issuer options.Issuer = issuer
options.JWKSURI = fmt.Sprintf("%s%s", issuer, EndpointPathJWKs) options.JWKSURI = fmt.Sprintf("%s%s", issuer, EndpointPathJWKs)
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
options.AuthorizationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathAuthorization) options.AuthorizationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathAuthorization)
options.PushedAuthorizationRequestEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathPushedAuthorizationRequest)
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation) options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation)
return options return options
@ -72,14 +72,14 @@ func (p *OpenIDConnectProvider) GetOpenIDConnectWellKnownConfiguration(issuer st
} }
options.Issuer = issuer options.Issuer = issuer
options.JWKSURI = fmt.Sprintf("%s%s", issuer, EndpointPathJWKs) options.JWKSURI = fmt.Sprintf("%s%s", issuer, EndpointPathJWKs)
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
options.AuthorizationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathAuthorization) options.AuthorizationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathAuthorization)
options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation) options.PushedAuthorizationRequestEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathPushedAuthorizationRequest)
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
options.UserinfoEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathUserinfo) options.UserinfoEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathUserinfo)
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation)
return options return options
} }

View File

@ -165,7 +165,7 @@ func (s *Store) InvalidateAuthorizeCodeSession(ctx context.Context, code string)
// This implements a portion of oauth2.AuthorizeCodeStorage. // This implements a portion of oauth2.AuthorizeCodeStorage.
func (s *Store) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) { func (s *Store) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) {
// TODO: Implement the fosite.ErrInvalidatedAuthorizeCode error above. This requires splitting the invalidated sessions and deleted sessions. // TODO: Implement the fosite.ErrInvalidatedAuthorizeCode error above. This requires splitting the invalidated sessions and deleted sessions.
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session) return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session)
} }
// CreateAccessTokenSession stores the authorization request for a given access token. // CreateAccessTokenSession stores the authorization request for a given access token.
@ -190,7 +190,7 @@ func (s *Store) RevokeAccessToken(ctx context.Context, requestID string) (err er
// GetAccessTokenSession gets the authorization request for a given access token. // GetAccessTokenSession gets the authorization request for a given access token.
// This implements a portion of oauth2.AccessTokenStorage. // This implements a portion of oauth2.AccessTokenStorage.
func (s *Store) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { func (s *Store) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session) return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session)
} }
// CreateRefreshTokenSession stores the authorization request for a given refresh token. // CreateRefreshTokenSession stores the authorization request for a given refresh token.
@ -223,7 +223,7 @@ func (s *Store) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestI
// GetRefreshTokenSession gets the authorization request for a given refresh token. // GetRefreshTokenSession gets the authorization request for a given refresh token.
// This implements a portion of oauth2.RefreshTokenStorage. // This implements a portion of oauth2.RefreshTokenStorage.
func (s *Store) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { func (s *Store) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session) return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session)
} }
// CreatePKCERequestSession stores the authorization request for a given PKCE request. // CreatePKCERequestSession stores the authorization request for a given PKCE request.
@ -241,7 +241,7 @@ func (s *Store) DeletePKCERequestSession(ctx context.Context, signature string)
// GetPKCERequestSession gets the authorization request for a given PKCE request. // GetPKCERequestSession gets the authorization request for a given PKCE request.
// This implements a portion of pkce.PKCERequestStorage. // This implements a portion of pkce.PKCERequestStorage.
func (s *Store) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (requester fosite.Requester, err error) { func (s *Store) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (requester fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session) return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session)
} }
// CreateOpenIDConnectSession creates an open id connect session for a given authorize code. // CreateOpenIDConnectSession creates an open id connect session for a given authorize code.
@ -263,7 +263,37 @@ func (s *Store) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode st
// - or an arbitrary error if an error occurred. // - or an arbitrary error if an error occurred.
// This implements a portion of openid.OpenIDConnectRequestStorage. // This implements a portion of openid.OpenIDConnectRequestStorage.
func (s *Store) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (r fosite.Requester, err error) { func (s *Store) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (r fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession()) return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession())
}
// CreatePARSession stores the pushed authorization request context. The requestURI is used to derive the key.
// This implements a portion of fosite.PARStorage.
func (s *Store) CreatePARSession(ctx context.Context, requestURI string, request fosite.AuthorizeRequester) (err error) {
var par *model.OAuth2PARContext
if par, err = model.NewOAuth2PARContext(requestURI, request); err != nil {
return err
}
return s.provider.SaveOAuth2PARContext(ctx, *par)
}
// GetPARSession gets the push authorization request context. The caller is expected to merge the AuthorizeRequest.
// This implements a portion of fosite.PARStorage.
func (s *Store) GetPARSession(ctx context.Context, requestURI string) (request fosite.AuthorizeRequester, err error) {
var par *model.OAuth2PARContext
if par, err = s.provider.LoadOAuth2PARContext(ctx, requestURI); err != nil {
return nil, err
}
return par.ToAuthorizeRequest(ctx, NewSession(), s)
}
// DeletePARSession deletes the context.
// This implements a portion of fosite.PARStorage.
func (s *Store) DeletePARSession(ctx context.Context, requestURI string) (err error) {
return s.provider.RevokeOAuth2PARContext(ctx, requestURI)
} }
// IsJWTUsed implements an interface required for RFC7523. // IsJWTUsed implements an interface required for RFC7523.
@ -280,7 +310,7 @@ func (s *Store) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Tim
return s.SetClientAssertionJWT(ctx, jti, exp) return s.SetClientAssertionJWT(ctx, jti, exp)
} }
func (s *Store) loadSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) { func (s *Store) loadRequesterBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) {
var ( var (
sessionModel *model.OAuth2Session sessionModel *model.OAuth2Session
) )

View File

@ -119,6 +119,8 @@ type Client struct {
ResponseTypes []string ResponseTypes []string
ResponseModes []fosite.ResponseModeType ResponseModes []fosite.ResponseModeType
EnforcePAR bool
UserinfoSigningAlgorithm string UserinfoSigningAlgorithm string
Policy authorization.Level Policy authorization.Level

View File

@ -331,6 +331,15 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
r.GET("/api/oidc/authorize", policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization)))) r.GET("/api/oidc/authorize", policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization))))
r.POST("/api/oidc/authorize", policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization)))) r.POST("/api/oidc/authorize", policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization))))
policyCORSPAR := middlewares.NewCORSPolicyBuilder().
WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
WithAllowedOrigins(allowedOrigins...).
WithEnabled(utils.IsStringInSliceFold(oidc.EndpointPushedAuthorizationRequest, config.IdentityProviders.OIDC.CORS.Endpoints)).
Build()
r.OPTIONS(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.HandleOnlyOPTIONS)
r.POST(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectPushedAuthorizationRequest))))
policyCORSToken := middlewares.NewCORSPolicyBuilder(). policyCORSToken := middlewares.NewCORSPolicyBuilder().
WithAllowCredentials(true). WithAllowCredentials(true).
WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost). WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).

View File

@ -13,15 +13,16 @@ const (
tableUserPreferences = "user_preferences" tableUserPreferences = "user_preferences"
tableWebauthnDevices = "webauthn_devices" tableWebauthnDevices = "webauthn_devices"
tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
tableOAuth2ConsentSession = "oauth2_consent_session" tableOAuth2ConsentSession = "oauth2_consent_session"
tableOAuth2ConsentPreConfiguration = "oauth2_consent_preconfiguration" tableOAuth2ConsentPreConfiguration = "oauth2_consent_preconfiguration"
tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential. tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential.
tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential. tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session" tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session"
tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti" tableOAuth2PARContext = "oauth2_par_context"
tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
tableMigrations = "migrations" tableMigrations = "migrations"
tableEncryption = "encryption" tableEncryption = "encryption"
@ -32,26 +33,29 @@ type OAuth2SessionType int
// Representation of specific OAuth 2.0 session types. // Representation of specific OAuth 2.0 session types.
const ( const (
OAuth2SessionTypeAuthorizeCode OAuth2SessionType = iota OAuth2SessionTypeAccessToken OAuth2SessionType = iota
OAuth2SessionTypeAccessToken OAuth2SessionTypeAuthorizeCode
OAuth2SessionTypeRefreshToken
OAuth2SessionTypePKCEChallenge
OAuth2SessionTypeOpenIDConnect OAuth2SessionTypeOpenIDConnect
OAuth2SessionTypePAR
OAuth2SessionTypePKCEChallenge
OAuth2SessionTypeRefreshToken
) )
// String returns a string representation of this OAuth2SessionType. // String returns a string representation of this OAuth2SessionType.
func (s OAuth2SessionType) String() string { func (s OAuth2SessionType) String() string {
switch s { switch s {
case OAuth2SessionTypeAuthorizeCode:
return "authorization code"
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
return "access token" return "access token"
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
return "refresh token" return "authorization code"
case OAuth2SessionTypePKCEChallenge:
return "pkce challenge"
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
return "openid connect" return "openid connect"
case OAuth2SessionTypePAR:
return "pushed authorization request context"
case OAuth2SessionTypePKCEChallenge:
return "pkce challenge"
case OAuth2SessionTypeRefreshToken:
return "refresh token"
default: default:
return "invalid" return "invalid"
} }
@ -60,16 +64,18 @@ func (s OAuth2SessionType) String() string {
// Table returns the table name for this session type. // Table returns the table name for this session type.
func (s OAuth2SessionType) Table() string { func (s OAuth2SessionType) Table() string {
switch s { switch s {
case OAuth2SessionTypeAuthorizeCode:
return tableOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
return tableOAuth2AccessTokenSession return tableOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
return tableOAuth2RefreshTokenSession return tableOAuth2AuthorizeCodeSession
case OAuth2SessionTypePKCEChallenge:
return tableOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
return tableOAuth2OpenIDConnectSession return tableOAuth2OpenIDConnectSession
case OAuth2SessionTypePAR:
return tableOAuth2PARContext
case OAuth2SessionTypePKCEChallenge:
return tableOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
return tableOAuth2RefreshTokenSession
default: default:
return "" return ""
} }
@ -119,7 +125,7 @@ const (
) )
var ( var (
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`) reMigration = regexp.MustCompile(`^V(?P<Version>\d{4})\.(?P<Name>[^.]+)\.(?P<Provider>(all|sqlite|postgres|mysql))\.(?P<Direction>(up|down))\.sql$`)
) )
const ( const (

View File

@ -130,15 +130,15 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
} }
func scanMigration(m string) (migration model.SchemaMigration, err error) { func scanMigration(m string) (migration model.SchemaMigration, err error) {
result := reMigration.FindStringSubmatch(m) if !reMigration.MatchString(m) {
if result == nil || len(result) != 5 {
return model.SchemaMigration{}, errors.New("invalid migration: could not parse the format") return model.SchemaMigration{}, errors.New("invalid migration: could not parse the format")
} }
result := reMigration.FindStringSubmatch(m)
migration = model.SchemaMigration{ migration = model.SchemaMigration{
Name: strings.ReplaceAll(result[2], "_", " "), Name: strings.ReplaceAll(result[reMigration.SubexpIndex("Name")], "_", " "),
Provider: result[3], Provider: result[reMigration.SubexpIndex("Provider")],
} }
data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)) data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
@ -148,22 +148,22 @@ func scanMigration(m string) (migration model.SchemaMigration, err error) {
migration.Query = string(data) migration.Query = string(data)
switch result[4] { switch direction := result[reMigration.SubexpIndex("Direction")]; direction {
case "up": case "up":
migration.Up = true migration.Up = true
case "down": case "down":
migration.Up = false migration.Up = false
default: default:
return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4]) return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in Direction group '%s' must be up or down", direction)
} }
migration.Version, _ = strconv.Atoi(result[1]) migration.Version, _ = strconv.Atoi(result[reMigration.SubexpIndex("Version")])
switch migration.Provider { switch migration.Provider {
case providerAll, providerSQLite, providerMySQL, providerPostgres: case providerAll, providerSQLite, providerMySQL, providerPostgres:
break break
default: default:
return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3]) return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in Provider group '%s' must be all, sqlite, postgres, or mysql", migration.Provider)
} }
return migration, nil return migration, nil

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS oauth2_par_context;

View File

@ -0,0 +1,17 @@
CREATE TABLE IF NOT EXISTS oauth2_par_context (
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
scopes TEXT NOT NULL,
audience TEXT NOT NULL,
handled_response_types TEXT NOT NULL,
response_mode TEXT NOT NULL,
response_mode_default TEXT NOT NULL,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
CREATE UNIQUE INDEX oauth2_par_context_signature_key ON oauth2_par_context (signature);

View File

@ -0,0 +1,17 @@
CREATE TABLE IF NOT EXISTS oauth2_par_context (
id SERIAL CONSTRAINT oauth2_par_context_pkey PRIMARY KEY,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
scopes TEXT NOT NULL,
audience TEXT NULL DEFAULT '',
handled_response_types TEXT NOT NULL DEFAULT '',
response_mode TEXT NOT NULL DEFAULT '',
response_mode_default TEXT NOT NULL DEFAULT '',
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BYTEA NOT NULL
);
CREATE UNIQUE INDEX oauth2_par_context_signature_key ON oauth2_par_context (signature);

View File

@ -0,0 +1,17 @@
CREATE TABLE IF NOT EXISTS oauth2_par_context (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
signature VARCHAR(255) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
scopes TEXT NOT NULL,
audience TEXT NOT NULL,
handled_response_types TEXT NOT NULL,
response_mode TEXT NOT NULL,
response_mode_default TEXT NOT NULL,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL
);
CREATE UNIQUE INDEX oauth2_par_context_signature_key ON oauth2_par_context (signature);

View File

@ -9,7 +9,7 @@ import (
const ( const (
// This is the latest schema version for the purpose of tests. // This is the latest schema version for the purpose of tests.
LatestVersion = 7 LatestVersion = 8
) )
func TestShouldObtainCorrectUpMigrations(t *testing.T) { func TestShouldObtainCorrectUpMigrations(t *testing.T) {

View File

@ -24,8 +24,8 @@ type Provider interface {
LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error)
SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error)
LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error)
LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error)
LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error)
SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error)
@ -65,6 +65,10 @@ type Provider interface {
DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error)
LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error)
SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error)
LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error)
RevokeOAuth2PARContext(ctx context.Context, signature string) (err error)
SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error)
LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error)

View File

@ -70,6 +70,13 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier),
sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier),
sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
sqlInsertOAuth2PARContext: fmt.Sprintf(queryFmtInsertOAuth2PARContext, tableOAuth2PARContext),
sqlSelectOAuth2PARContext: fmt.Sprintf(queryFmtSelectOAuth2PARContext, tableOAuth2PARContext),
sqlRevokeOAuth2PARContext: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PARContext),
sqlInsertOAuth2ConsentPreConfiguration: fmt.Sprintf(queryFmtInsertOAuth2ConsentPreConfiguration, tableOAuth2ConsentPreConfiguration), sqlInsertOAuth2ConsentPreConfiguration: fmt.Sprintf(queryFmtInsertOAuth2ConsentPreConfiguration, tableOAuth2ConsentPreConfiguration),
sqlSelectOAuth2ConsentPreConfigurations: fmt.Sprintf(queryFmtSelectOAuth2ConsentPreConfigurations, tableOAuth2ConsentPreConfiguration), sqlSelectOAuth2ConsentPreConfigurations: fmt.Sprintf(queryFmtSelectOAuth2ConsentPreConfigurations, tableOAuth2ConsentPreConfiguration),
@ -79,13 +86,6 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession), sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession),
sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession), sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession),
sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession), sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession),
sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession), sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession),
sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession), sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession),
@ -93,19 +93,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession), sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession),
sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession), sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession), sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession), sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession), sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession), sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession), sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession), sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession),
@ -114,8 +107,19 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession), sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession), sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
@ -224,13 +228,18 @@ type SQLProvider struct {
sqlDeactivateOAuth2AccessTokenSession string sqlDeactivateOAuth2AccessTokenSession string
sqlDeactivateOAuth2AccessTokenSessionByRequestID string sqlDeactivateOAuth2AccessTokenSessionByRequestID string
// Table: oauth2_refresh_token_session. // Table: oauth2_openid_connect_session.
sqlInsertOAuth2RefreshTokenSession string sqlInsertOAuth2OpenIDConnectSession string
sqlSelectOAuth2RefreshTokenSession string sqlSelectOAuth2OpenIDConnectSession string
sqlRevokeOAuth2RefreshTokenSession string sqlRevokeOAuth2OpenIDConnectSession string
sqlRevokeOAuth2RefreshTokenSessionByRequestID string sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
sqlDeactivateOAuth2RefreshTokenSession string sqlDeactivateOAuth2OpenIDConnectSession string
sqlDeactivateOAuth2RefreshTokenSessionByRequestID string sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
// Table: oauth2_par_context.
sqlInsertOAuth2PARContext string
sqlSelectOAuth2PARContext string
sqlRevokeOAuth2PARContext string
// Table: oauth2_pkce_request_session. // Table: oauth2_pkce_request_session.
sqlInsertOAuth2PKCERequestSession string sqlInsertOAuth2PKCERequestSession string
@ -240,13 +249,13 @@ type SQLProvider struct {
sqlDeactivateOAuth2PKCERequestSession string sqlDeactivateOAuth2PKCERequestSession string
sqlDeactivateOAuth2PKCERequestSessionByRequestID string sqlDeactivateOAuth2PKCERequestSessionByRequestID string
// Table: oauth2_openid_connect_session. // Table: oauth2_refresh_token_session.
sqlInsertOAuth2OpenIDConnectSession string sqlInsertOAuth2RefreshTokenSession string
sqlSelectOAuth2OpenIDConnectSession string sqlSelectOAuth2RefreshTokenSession string
sqlRevokeOAuth2OpenIDConnectSession string sqlRevokeOAuth2RefreshTokenSession string
sqlRevokeOAuth2OpenIDConnectSessionByRequestID string sqlRevokeOAuth2RefreshTokenSessionByRequestID string
sqlDeactivateOAuth2OpenIDConnectSession string sqlDeactivateOAuth2RefreshTokenSession string
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
sqlUpsertOAuth2BlacklistedJTI string sqlUpsertOAuth2BlacklistedJTI string
sqlSelectOAuth2BlacklistedJTI string sqlSelectOAuth2BlacklistedJTI string
@ -339,19 +348,19 @@ func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
} }
// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database. // SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) { func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil { if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil {
return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err) return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err)
} }
return nil return nil
} }
// LoadUserOpaqueIdentifier selects an opaque user identifier from the database. // LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) { func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) {
opaqueID = &model.UserOpaqueIdentifier{} subject = &model.UserOpaqueIdentifier{}
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil { if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifier, identifier); err != nil {
switch { switch {
case errors.Is(err, sql.ErrNoRows): case errors.Is(err, sql.ErrNoRows):
return nil, nil return nil, nil
@ -360,11 +369,11 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID u
} }
} }
return opaqueID, nil return subject, nil
} }
// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database. // LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database.
func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) { func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error) {
var rows *sqlx.Rows var rows *sqlx.Rows
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil { if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil {
@ -380,17 +389,17 @@ func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs
return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err) return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err)
} }
opaqueIDs = append(opaqueIDs, *opaqueID) identifiers = append(identifiers, *opaqueID)
} }
return opaqueIDs, nil return identifiers, nil
} }
// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username. // LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) {
opaqueID = &model.UserOpaqueIdentifier{} subject = &model.UserOpaqueIdentifier{}
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil { if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
switch { switch {
case errors.Is(err, sql.ErrNoRows): case errors.Is(err, sql.ErrNoRows):
return nil, nil return nil, nil
@ -399,7 +408,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s
} }
} }
return opaqueID, nil return subject, nil
} }
// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session. // SaveOAuth2ConsentSession inserts an OAuth2.0 consent session.
@ -496,22 +505,22 @@ func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2S
var query string var query string
switch sessionType { switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlInsertOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
query = p.sqlInsertOAuth2AccessTokenSession query = p.sqlInsertOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
query = p.sqlInsertOAuth2RefreshTokenSession query = p.sqlInsertOAuth2AuthorizeCodeSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlInsertOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
query = p.sqlInsertOAuth2OpenIDConnectSession query = p.sqlInsertOAuth2OpenIDConnectSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlInsertOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlInsertOAuth2RefreshTokenSession
default: default:
return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType) return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType)
} }
if session.Session, err = p.encrypt(session.Session); err != nil { if session.Session, err = p.encrypt(session.Session); err != nil {
return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) return fmt.Errorf("error encrypting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
} }
_, err = p.db.ExecContext(ctx, query, _, err = p.db.ExecContext(ctx, query,
@ -532,16 +541,16 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth
var query string var query string
switch sessionType { switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlRevokeOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
query = p.sqlRevokeOAuth2AccessTokenSession query = p.sqlRevokeOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
query = p.sqlRevokeOAuth2RefreshTokenSession query = p.sqlRevokeOAuth2AuthorizeCodeSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlRevokeOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
query = p.sqlRevokeOAuth2OpenIDConnectSession query = p.sqlRevokeOAuth2OpenIDConnectSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlRevokeOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlRevokeOAuth2RefreshTokenSession
default: default:
return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String()) return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
} }
@ -558,16 +567,16 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio
var query string var query string
switch sessionType { switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
case OAuth2SessionTypePKCEChallenge:
query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID
case OAuth2SessionTypePKCEChallenge:
query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
case OAuth2SessionTypeRefreshToken:
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
default: default:
return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String()) return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
} }
@ -584,16 +593,16 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O
var query string var query string
switch sessionType { switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
query = p.sqlDeactivateOAuth2AccessTokenSession query = p.sqlDeactivateOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
query = p.sqlDeactivateOAuth2RefreshTokenSession query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlDeactivateOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
query = p.sqlDeactivateOAuth2OpenIDConnectSession query = p.sqlDeactivateOAuth2OpenIDConnectSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlDeactivateOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlDeactivateOAuth2RefreshTokenSession
default: default:
return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String()) return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
} }
@ -610,16 +619,16 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se
var query string var query string
switch sessionType { switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID
case OAuth2SessionTypePKCEChallenge:
query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
case OAuth2SessionTypeRefreshToken:
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
default: default:
return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String()) return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
} }
@ -636,16 +645,16 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
var query string var query string
switch sessionType { switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlSelectOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken: case OAuth2SessionTypeAccessToken:
query = p.sqlSelectOAuth2AccessTokenSession query = p.sqlSelectOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken: case OAuth2SessionTypeAuthorizeCode:
query = p.sqlSelectOAuth2RefreshTokenSession query = p.sqlSelectOAuth2AuthorizeCodeSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlSelectOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect: case OAuth2SessionTypeOpenIDConnect:
query = p.sqlSelectOAuth2OpenIDConnectSession query = p.sqlSelectOAuth2OpenIDConnectSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlSelectOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlSelectOAuth2RefreshTokenSession
default: default:
return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String()) return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String())
} }
@ -663,6 +672,45 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
return session, nil return session, nil
} }
// SaveOAuth2PARContext save a OAuth2PARContext to the database.
func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) {
if par.Session, err = p.encrypt(par.Session); err != nil {
return fmt.Errorf("error encrypting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
}
if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2PARContext,
par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes,
par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session); err != nil {
return fmt.Errorf("error inserting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
}
return nil
}
// LoadOAuth2PARContext loads a OAuth2PARContext from the database.
func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) {
par = &model.OAuth2PARContext{}
if err = p.db.GetContext(ctx, par, p.sqlSelectOAuth2PARContext, signature); err != nil {
return nil, fmt.Errorf("error selecting oauth2 pushed authorization request context with signature '%s': %w", signature, err)
}
if par.Session, err = p.decrypt(par.Session); err != nil {
return nil, fmt.Errorf("error decrypting oauth2 oauth2 pushed authorization request context data with signature '%s' and request id '%s': %w", signature, par.RequestID, err)
}
return par, nil
}
// RevokeOAuth2PARContext marks a OAuth2PARContext as revoked in the database.
func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOAuth2PARContext, signature); err != nil {
return fmt.Errorf("error revoking oauth2 pushed authorization request context with signature '%s': %w", signature, err)
}
return nil
}
// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database. // SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) { func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil { if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
@ -762,7 +810,7 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database. // SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
if config.Secret, err = p.encrypt(config.Secret); err != nil { if config.Secret, err = p.encrypt(config.Secret); err != nil {
return fmt.Errorf("error encrypting the TOTP configuration secret for user '%s': %w", config.Username, err) return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err)
} }
if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
@ -806,7 +854,7 @@ func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string
} }
if config.Secret, err = p.decrypt(config.Secret); err != nil { if config.Secret, err = p.decrypt(config.Secret); err != nil {
return nil, fmt.Errorf("error decrypting the TOTP secret for user '%s': %w", username, err) return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err)
} }
return config, nil return config, nil
@ -836,7 +884,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
// SaveWebauthnDevice saves a registered Webauthn device. // SaveWebauthnDevice saves a registered Webauthn device.
func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) { func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) {
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil { if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
return fmt.Errorf("error encrypting the Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err) return fmt.Errorf("error encrypting Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
} }
if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice, if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice,

View File

@ -87,13 +87,6 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
provider.sqlUpdateOAuth2ConsentSessionGranted = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionGranted) provider.sqlUpdateOAuth2ConsentSessionGranted = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionGranted)
provider.sqlSelectOAuth2ConsentSessionByChallengeID = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionByChallengeID) provider.sqlSelectOAuth2ConsentSessionByChallengeID = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionByChallengeID)
provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession)
provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession)
provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID)
provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession)
provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID)
provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession)
provider.sqlInsertOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2AccessTokenSession) provider.sqlInsertOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2AccessTokenSession)
provider.sqlRevokeOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSession) provider.sqlRevokeOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSession)
provider.sqlRevokeOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSessionByRequestID) provider.sqlRevokeOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSessionByRequestID)
@ -101,12 +94,23 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID) provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID)
provider.sqlSelectOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2AccessTokenSession) provider.sqlSelectOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2AccessTokenSession)
provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession) provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession)
provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession) provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession)
provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID) provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID)
provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession) provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession)
provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID) provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID)
provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession) provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession)
provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession)
provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession)
provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID)
provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession)
provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID)
provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession)
provider.sqlInsertOAuth2PARContext = provider.db.Rebind(provider.sqlInsertOAuth2PARContext)
provider.sqlRevokeOAuth2PARContext = provider.db.Rebind(provider.sqlRevokeOAuth2PARContext)
provider.sqlSelectOAuth2PARContext = provider.db.Rebind(provider.sqlSelectOAuth2PARContext)
provider.sqlInsertOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlInsertOAuth2PKCERequestSession) provider.sqlInsertOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlInsertOAuth2PKCERequestSession)
provider.sqlRevokeOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSession) provider.sqlRevokeOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSession)
@ -115,12 +119,12 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID) provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID)
provider.sqlSelectOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlSelectOAuth2PKCERequestSession) provider.sqlSelectOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlSelectOAuth2PKCERequestSession)
provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession) provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession)
provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession) provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession)
provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID) provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID)
provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession) provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession)
provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID) provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID)
provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession) provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession)
provider.sqlSelectOAuth2BlacklistedJTI = provider.db.Rebind(provider.sqlSelectOAuth2BlacklistedJTI) provider.sqlSelectOAuth2BlacklistedJTI = provider.db.Rebind(provider.sqlSelectOAuth2BlacklistedJTI)

View File

@ -314,6 +314,19 @@ const (
SET active = FALSE SET active = FALSE
WHERE request_id = ?;` WHERE request_id = ?;`
queryFmtSelectOAuth2PARContext = `
SELECT id, signature, request_id, client_id, requested_at, scopes, audience,
handled_response_types, response_mode, response_mode_default, revoked,
form_data, session_data
FROM %s
WHERE signature = ? AND revoked = FALSE;`
queryFmtInsertOAuth2PARContext = `
INSERT INTO %s (signature, request_id, client_id, requested_at, scopes, audience,
handled_response_types, response_mode, response_mode_default, revoked,
form_data, session_data)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtSelectOAuth2BlacklistedJTI = ` queryFmtSelectOAuth2BlacklistedJTI = `
SELECT id, signature, expires_at SELECT id, signature, expires_at
FROM %s FROM %s

View File

@ -1132,6 +1132,7 @@ func (s *CLISuite) TestStorage05ShouldChangeEncryptionKey() {
s.Assert().Contains(output, "\n\n\tTable (oauth2_openid_connect_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (oauth2_openid_connect_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (oauth2_pkce_request_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (oauth2_pkce_request_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (oauth2_refresh_token_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (oauth2_refresh_token_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (oauth2_par_context): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (totp_configurations): FAILURE\n\t\tInvalid Rows: 4\n\t\tTotal Rows: 4\n") s.Assert().Contains(output, "\n\n\tTable (totp_configurations): FAILURE\n\t\tInvalid Rows: 4\n\t\tTotal Rows: 4\n")
s.Assert().Contains(output, "\n\n\tTable (webauthn_devices): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (webauthn_devices): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
@ -1149,6 +1150,7 @@ func (s *CLISuite) TestStorage05ShouldChangeEncryptionKey() {
s.Assert().Contains(output, "\n\n\tTable (oauth2_openid_connect_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (oauth2_openid_connect_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (oauth2_pkce_request_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (oauth2_pkce_request_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (oauth2_refresh_token_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (oauth2_refresh_token_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (oauth2_par_context): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
s.Assert().Contains(output, "\n\n\tTable (totp_configurations): SUCCESS\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 4\n") s.Assert().Contains(output, "\n\n\tTable (totp_configurations): SUCCESS\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 4\n")
s.Assert().Contains(output, "\n\n\tTable (webauthn_devices): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n") s.Assert().Contains(output, "\n\n\tTable (webauthn_devices): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")

View File

@ -104,8 +104,13 @@ func IsStringSliceContainsAll(needles []string, haystack []string) (inSlice bool
// IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles. // IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles.
func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) { func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) {
return IsStringSliceContainsAnyF(needles, haystack, IsStringInSlice)
}
// IsStringSliceContainsAnyF checks if the haystack contains any of the strings in the needles using the isInSlice func.
func IsStringSliceContainsAnyF(needles []string, haystack []string, isInSlice func(needle string, haystack []string) bool) (inSlice bool) {
for _, n := range needles { for _, n := range needles {
if IsStringInSlice(n, haystack) { if isInSlice(n, haystack) {
return true return true
} }
} }