feat(oidc): pushed authorization requests (#4546)
This implements RFC9126 OAuth 2.0 Pushed Authorization Requests. See https://datatracker.ietf.org/doc/html/rfc9126 for the specification details.pull/5033/head^2
parent
42671d3edb
commit
ff6be40f5e
|
@ -272,6 +272,23 @@ Allows [PKCE] `plain` challenges when set to `true`.
|
|||
*__Security Notice:__* Changing this value is generally discouraged. Applications should use the `S256` [PKCE] challenge
|
||||
method instead.
|
||||
|
||||
### pushed_authorizations
|
||||
|
||||
Controls the behaviour of [Pushed Authorization Requests].
|
||||
|
||||
#### enforce
|
||||
|
||||
{{< confkey type="boolean" default="false" required="no" >}}
|
||||
|
||||
When enabled all authorization requests must use the [Pushed Authorization Requests] flow.
|
||||
|
||||
#### context_lifespan
|
||||
|
||||
{{< confkey type="duration" default="5m" required="no" >}}
|
||||
|
||||
The maximum amount of time between the [Pushed Authorization Requests] flow being initiated and the generated
|
||||
`request_uri` being utilized by a client.
|
||||
|
||||
### cors
|
||||
|
||||
Some [OpenID Connect 1.0] Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
|
||||
|
@ -285,6 +302,7 @@ A list of endpoints to configure with cross-origin resource sharing headers. It
|
|||
option is at least in this list. The potential endpoints which this can be enabled on are as follows:
|
||||
|
||||
* authorization
|
||||
* pushed-authorization-request
|
||||
* token
|
||||
* revocation
|
||||
* introspection
|
||||
|
@ -472,6 +490,12 @@ See the [Response Modes](../../integration/openid-connect/introduction.md#respon
|
|||
|
||||
The authorization policy for this client: either `one_factor` or `two_factor`.
|
||||
|
||||
#### enforce_par
|
||||
|
||||
{{< confkey type="boolean" default="false" required="no" >}}
|
||||
|
||||
Enforces the use of a [Pushed Authorization Requests] flow for this client.
|
||||
|
||||
#### enforce_pkce
|
||||
|
||||
{{< confkey type="bool" default="false" required="no" >}}
|
||||
|
@ -550,3 +574,4 @@ To integrate Authelia's [OpenID Connect 1.0] implementation with a relying party
|
|||
[Authorization Code Flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth
|
||||
[Subject Identifier Type]: https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
|
||||
[Pairwise Identifier Algorithm]: https://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg
|
||||
[Pushed Authorization Requests]: https://datatracker.ietf.org/doc/html/rfc9126
|
||||
|
|
|
@ -36,3 +36,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
|
|||
| 5 | 4.35.1 | Fixed the oauth2_consent_session table to accept NULL subjects for users who are not yet signed in |
|
||||
| 6 | 4.37.0 | Adjusted the OpenID Connect tables to allow pre-configured consent improvements |
|
||||
| 7 | 4.37.3 | Fixed some schema inconsistencies most notably the MySQL/MariaDB Engine and Collation |
|
||||
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |
|
||||
|
|
|
@ -210,14 +210,71 @@ These endpoints can be utilized to discover other endpoints and metadata about t
|
|||
These endpoints implement OpenID Connect elements.
|
||||
|
||||
| Endpoint | Path | Discovery Attribute |
|
||||
|:-------------------:|:-----------------------------------------------:|:----------------------:|
|
||||
| [JSON Web Key Sets] | https://auth.example.com/jwks.json | jwks_uri |
|
||||
|:-------------------------------:|:--------------------------------------------------------------:|:-------------------------------------:|
|
||||
| [JSON Web Key Set] | https://auth.example.com/jwks.json | jwks_uri |
|
||||
| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
|
||||
| [Pushed Authorization Requests] | https://auth.example.com/api/oidc/pushed-authorization-request | pushed_authorization_request_endpoint |
|
||||
| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
|
||||
| [UserInfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
|
||||
| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
|
||||
| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
|
||||
|
||||
## Security
|
||||
|
||||
The following information covers some security topics some users may wish to be familiar with.
|
||||
|
||||
#### Pushed Authorization Requests Endpoint
|
||||
|
||||
The [Pushed Authorization Requests] endpoint is discussed in depth in [RFC9126] as well as in the
|
||||
[OAuth 2.0 Pushed Authorization Requests](https://oauth.net/2/pushed-authorization-requests/) documentation.
|
||||
|
||||
Essentially it's a special endpoint that takes the same parameters as the [Authorization] endpoint (including
|
||||
[Proof Key Code Exchange](#proof-key-code-exchange)) with a few caveats:
|
||||
|
||||
1. The same [Client Authentication] mechanism required by the [Token] endpoint **MUST** be used.
|
||||
2. The request **MUST** use the [HTTP POST method].
|
||||
3. The request **MUST** use the `application/x-www-form-urlencoded` content type (i.e. the parameters **MUST** be in the
|
||||
body, not the URI).
|
||||
4. The request **MUST** occur over the back-channel.
|
||||
|
||||
The response of this endpoint is a JSON Object with two key-value pairs:
|
||||
- `request_uri`
|
||||
- `expires_in`
|
||||
|
||||
The `expires_in` indicates how long the `request_uri` is valid for. The `request_uri` is used as a parameter to the
|
||||
[Authorization] endpoint instead of the standard parameters (as the `request_uri` parameter).
|
||||
|
||||
The advantages of this approach are as follows:
|
||||
|
||||
1. [Pushed Authorization Requests] cannot be created or influenced by any party other than the Relying Party (client).
|
||||
2. Since you can force all [Authorization] requests to be initiated via [Pushed Authorization Requests] you drastically
|
||||
improve the authorization flows resistance to phishing attacks (this can be done globally or on a per-client basis).
|
||||
3. Since the [Pushed Authorization Requests] endpoint requires all of the same [Client Authentication] mechanisms as the
|
||||
[Token] endpoint:
|
||||
1. Clients using the confidential [Client Type] can't have [Pushed Authorization Requests] generated by parties who do not
|
||||
have the credentials.
|
||||
2. Clients using the public [Client Type] and utilizing [Proof Key Code Exchange](#proof-key-code-exchange) never
|
||||
transmit the verifier over any front-channel making even the `plain` challenge method relatively secure.
|
||||
|
||||
#### Proof Key Code Exchange
|
||||
|
||||
The [Proof Key Code Exchange] mechanism is discussed in depth in [RFC7636] as well as in the
|
||||
[OAuth 2.0 Proof Key Code Exchange](https://oauth.net/2/pkce/) documentation.
|
||||
|
||||
Essentially a random opaque value is generated by the Relying Party and optionally (but recommended) passed through a
|
||||
SHA256 hash. The original value is saved by the Relying Party, and the hashed value is sent in the [Authorization]
|
||||
request in the `code_verifier` parameter with the `code_challenge_method` set to `S256` (or `plain` using a bad practice
|
||||
of not hashing the opaque value).
|
||||
|
||||
When the Relying Party requests the token from the [Token] endpoint, they must include the `code_verifier` parameter
|
||||
again (in the body), but this time they send the value without it being hashed.
|
||||
|
||||
The advantages of this approach are as follows:
|
||||
|
||||
1. Provided the value was hashed it's certain that the Relying Party which generated the authorization request is the
|
||||
same party as the one requesting the token or is permitted by the Relying Party to make this request.
|
||||
2. Even when using the public [Client Type] there is a form of authentication on the [Token] endpoint.
|
||||
|
||||
[ID Token]: https://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||
[Access Token]: https://datatracker.ietf.org/doc/html/rfc6749#section-1.4
|
||||
[Refresh Token]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
|
@ -230,14 +287,23 @@ These endpoints implement OpenID Connect elements.
|
|||
[OpenID Connect Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
[OAuth 2.0 Authorization Server Metadata]: https://datatracker.ietf.org/doc/html/rfc8414
|
||||
|
||||
[JSON Web Key Sets]: https://datatracker.ietf.org/doc/html/rfc7517#section-5
|
||||
[JSON Web Key Set]: https://datatracker.ietf.org/doc/html/rfc7517#section-5
|
||||
|
||||
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
|
||||
[Pushed Authorization Requests]: https://datatracker.ietf.org/doc/html/rfc9126
|
||||
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
[UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
|
||||
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
|
||||
[Proof Key Code Exchange]: https://www.rfc-editor.org/rfc/rfc7636.html
|
||||
|
||||
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
|
||||
[RFC4122]: https://datatracker.ietf.org/doc/html/rfc4122
|
||||
[Subject Identifier Types]: https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
|
||||
[Client Authentication]: https://datatracker.ietf.org/doc/html/rfc6749#section-2.3
|
||||
[Client Type]: https://oauth.net/2/client-types/
|
||||
[HTTP POST method]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/POST
|
||||
[Proof Key Code Exchange]: #proof-key-code-exchange
|
||||
|
||||
[RFC4122]: https://datatracker.ietf.org/doc/html/rfc4122
|
||||
[RFC7636]: https://datatracker.ietf.org/doc/html/rfc7636
|
||||
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
|
||||
[RFC9126]: https://datatracker.ietf.org/doc/html/rfc9126
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -29,10 +29,17 @@ type OpenIDConnectConfiguration struct {
|
|||
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
|
||||
|
||||
CORS OpenIDConnectCORSConfiguration `koanf:"cors"`
|
||||
PAR OpenIDConnectPARConfiguration `koanf:"pushed_authorizations"`
|
||||
|
||||
Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
|
||||
}
|
||||
|
||||
// OpenIDConnectPARConfiguration represents an OpenID Connect PAR config.
|
||||
type OpenIDConnectPARConfiguration struct {
|
||||
Enforce bool `koanf:"enforce"`
|
||||
ContextLifespan time.Duration `koanf:"context_lifespan"`
|
||||
}
|
||||
|
||||
// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config.
|
||||
type OpenIDConnectCORSConfiguration struct {
|
||||
Endpoints []string `koanf:"endpoints"`
|
||||
|
@ -59,6 +66,7 @@ type OpenIDConnectClientConfiguration struct {
|
|||
|
||||
Policy string `koanf:"authorization_policy"`
|
||||
|
||||
EnforcePAR bool `koanf:"enforce_par"`
|
||||
EnforcePKCE bool `koanf:"enforce_pkce"`
|
||||
|
||||
PKCEChallengeMethod string `koanf:"pkce_challenge_method"`
|
||||
|
|
|
@ -31,6 +31,8 @@ var Keys = []string{
|
|||
"identity_providers.oidc.cors.endpoints",
|
||||
"identity_providers.oidc.cors.allowed_origins",
|
||||
"identity_providers.oidc.cors.allowed_origins_from_client_redirect_uris",
|
||||
"identity_providers.oidc.pushed_authorizations.enforce",
|
||||
"identity_providers.oidc.pushed_authorizations.context_lifespan",
|
||||
"identity_providers.oidc.clients",
|
||||
"identity_providers.oidc.clients[].id",
|
||||
"identity_providers.oidc.clients[].description",
|
||||
|
@ -44,6 +46,7 @@ var Keys = []string{
|
|||
"identity_providers.oidc.clients[].response_types",
|
||||
"identity_providers.oidc.clients[].response_modes",
|
||||
"identity_providers.oidc.clients[].authorization_policy",
|
||||
"identity_providers.oidc.clients[].enforce_par",
|
||||
"identity_providers.oidc.clients[].enforce_pkce",
|
||||
"identity_providers.oidc.clients[].pkce_challenge_method",
|
||||
"identity_providers.oidc.clients[].userinfo_signing_algorithm",
|
||||
|
|
|
@ -392,7 +392,7 @@ var (
|
|||
validOIDCGrantTypes = []string{oidc.GrantTypeImplicit, oidc.GrantTypeRefreshToken, oidc.GrantTypeAuthorizationCode, oidc.GrantTypePassword, oidc.GrantTypeClientCredentials}
|
||||
validOIDCResponseModes = []string{oidc.ResponseModeFormPost, oidc.ResponseModeQuery, oidc.ResponseModeFragment}
|
||||
validOIDCUserinfoAlgorithms = []string{oidc.SigningAlgorithmNone, oidc.SigningAlgorithmRSAWithSHA256}
|
||||
validOIDCCORSEndpoints = []string{oidc.EndpointAuthorization, oidc.EndpointToken, oidc.EndpointIntrospection, oidc.EndpointRevocation, oidc.EndpointUserinfo}
|
||||
validOIDCCORSEndpoints = []string{oidc.EndpointAuthorization, oidc.EndpointPushedAuthorizationRequest, oidc.EndpointToken, oidc.EndpointIntrospection, oidc.EndpointRevocation, oidc.EndpointUserinfo}
|
||||
validOIDCClientConsentModes = []string{"auto", oidc.ClientConsentModeImplicit.String(), oidc.ClientConsentModeExplicit.String(), oidc.ClientConsentModePreConfigured.String()}
|
||||
)
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) {
|
|||
|
||||
require.Len(t, validator.Errors(), 1)
|
||||
|
||||
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'")
|
||||
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'pushed-authorization-request', 'token', 'introspection', 'revocation', 'userinfo'")
|
||||
}
|
||||
|
||||
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
|
||||
|
|
|
@ -53,10 +53,20 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
|
|||
return
|
||||
}
|
||||
|
||||
if err = client.ValidateAuthorizationPolicy(requester); err != nil {
|
||||
if err = client.ValidatePARPolicy(requester, ctx.Providers.OpenIDConnect.GetPushedAuthorizeRequestURIPrefix(ctx)); err != nil {
|
||||
rfc := fosite.ErrorToRFC6749Error(err)
|
||||
|
||||
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the authorization policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
|
||||
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the PAR policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
|
||||
|
||||
ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = client.ValidatePKCEPolicy(requester); err != nil {
|
||||
rfc := fosite.ErrorToRFC6749Error(err)
|
||||
|
||||
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
|
||||
|
||||
ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err)
|
||||
|
||||
|
@ -95,13 +105,13 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
|
|||
|
||||
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID)
|
||||
|
||||
oidcSession := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(),
|
||||
session := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(),
|
||||
userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, consent, requester)
|
||||
|
||||
ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v",
|
||||
requester.GetID(), oidcSession.ClientID, oidcSession.Subject, oidcSession.Username, oidcSession.Claims)
|
||||
requester.GetID(), session.ClientID, session.Subject, session.Username, session.Claims)
|
||||
|
||||
if responder, err = ctx.Providers.OpenIDConnect.NewAuthorizeResponse(ctx, requester, oidcSession); err != nil {
|
||||
if responder, err = ctx.Providers.OpenIDConnect.NewAuthorizeResponse(ctx, requester, session); err != nil {
|
||||
rfc := fosite.ErrorToRFC6749Error(err)
|
||||
|
||||
ctx.Logger.Errorf("Authorization Response for Request with id '%s' on client with id '%s' could not be created: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
|
||||
|
@ -125,3 +135,62 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
|
|||
|
||||
ctx.Providers.OpenIDConnect.WriteAuthorizeResponse(ctx, rw, requester, responder)
|
||||
}
|
||||
|
||||
// OpenIDConnectPushedAuthorizationRequest handles POST requests to the OAuth 2.0 Pushed Authorization Requests endpoint.
|
||||
//
|
||||
// RFC9126 https://www.rfc-editor.org/rfc/rfc9126.html
|
||||
func OpenIDConnectPushedAuthorizationRequest(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
requester fosite.AuthorizeRequester
|
||||
responder fosite.PushedAuthorizeResponder
|
||||
err error
|
||||
)
|
||||
|
||||
if requester, err = ctx.Providers.OpenIDConnect.NewPushedAuthorizeRequest(ctx, r); err != nil {
|
||||
rfc := fosite.ErrorToRFC6749Error(err)
|
||||
|
||||
ctx.Logger.Errorf("Pushed Authorization Request failed with error: %s", rfc.WithExposeDebug(true).GetDescription())
|
||||
|
||||
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var client *oidc.Client
|
||||
|
||||
clientID := requester.GetClient().GetID()
|
||||
|
||||
if client, err = ctx.Providers.OpenIDConnect.GetFullClient(clientID); err != nil {
|
||||
if errors.Is(err, fosite.ErrNotFound) {
|
||||
ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' could not be processed: client was not found", requester.GetID(), clientID)
|
||||
} else {
|
||||
ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' could not be processed: failed to find client: %+v", requester.GetID(), clientID, err)
|
||||
}
|
||||
|
||||
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = client.ValidatePKCEPolicy(requester); err != nil {
|
||||
rfc := fosite.ErrorToRFC6749Error(err)
|
||||
|
||||
ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription())
|
||||
|
||||
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if responder, err = ctx.Providers.OpenIDConnect.NewPushedAuthorizeResponse(ctx, requester, oidc.NewSession()); err != nil {
|
||||
rfc := fosite.ErrorToRFC6749Error(err)
|
||||
|
||||
ctx.Logger.Errorf("Pushed Authorization Request failed with error: %s", rfc.WithExposeDebug(true).GetDescription())
|
||||
|
||||
ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Providers.OpenIDConnect.WritePushedAuthorizeResponse(ctx, rw, requester, responder)
|
||||
}
|
||||
|
|
|
@ -9,9 +9,8 @@ import (
|
|||
mail "net/mail"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
templates "github.com/authelia/authelia/v4/internal/templates"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockNotifier is a mock of Notifier interface.
|
||||
|
|
|
@ -10,11 +10,10 @@ import (
|
|||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
uuid "github.com/google/uuid"
|
||||
|
||||
model "github.com/authelia/authelia/v4/internal/model"
|
||||
storage "github.com/authelia/authelia/v4/internal/storage"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
uuid "github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockStorage is a mock of Provider interface.
|
||||
|
@ -40,6 +39,7 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
|
||||
// AppendAuthenticationLog mocks base method.
|
||||
func (m *MockStorage) AppendAuthenticationLog(arg0 context.Context, arg1 model.AuthenticationAttempt) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -270,6 +270,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionByChallengeID(arg0, a
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionByChallengeID", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionByChallengeID), arg0, arg1)
|
||||
}
|
||||
|
||||
// LoadOAuth2PARContext mocks base method.
|
||||
func (m *MockStorage) LoadOAuth2PARContext(arg0 context.Context, arg1 string) (*model.OAuth2PARContext, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadOAuth2PARContext", arg0, arg1)
|
||||
ret0, _ := ret[0].(*model.OAuth2PARContext)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadOAuth2PARContext indicates an expected call of LoadOAuth2PARContext.
|
||||
func (mr *MockStorageMockRecorder) LoadOAuth2PARContext(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2PARContext", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2PARContext), arg0, arg1)
|
||||
}
|
||||
|
||||
// LoadOAuth2Session mocks base method.
|
||||
func (m *MockStorage) LoadOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) (*model.OAuth2Session, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -435,6 +450,20 @@ func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1 inte
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1)
|
||||
}
|
||||
|
||||
// RevokeOAuth2PARContext mocks base method.
|
||||
func (m *MockStorage) RevokeOAuth2PARContext(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RevokeOAuth2PARContext", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RevokeOAuth2PARContext indicates an expected call of RevokeOAuth2PARContext.
|
||||
func (mr *MockStorageMockRecorder) RevokeOAuth2PARContext(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2PARContext", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2PARContext), arg0, arg1)
|
||||
}
|
||||
|
||||
// RevokeOAuth2Session mocks base method.
|
||||
func (m *MockStorage) RevokeOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -576,6 +605,20 @@ func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionSubject(arg0, arg1 in
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionSubject", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionSubject), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveOAuth2PARContext mocks base method.
|
||||
func (m *MockStorage) SaveOAuth2PARContext(arg0 context.Context, arg1 model.OAuth2PARContext) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveOAuth2PARContext", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveOAuth2PARContext indicates an expected call of SaveOAuth2PARContext.
|
||||
func (mr *MockStorageMockRecorder) SaveOAuth2PARContext(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2PARContext", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2PARContext), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveOAuth2Session mocks base method.
|
||||
func (m *MockStorage) SaveOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 model.OAuth2Session) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -7,9 +7,8 @@ package mocks
|
|||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
model "github.com/authelia/authelia/v4/internal/model"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockTOTP is a mock of Provider interface.
|
||||
|
|
|
@ -7,9 +7,8 @@ package mocks
|
|||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
authentication "github.com/authelia/authelia/v4/internal/authentication"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUserProvider is a mock of UserProvider interface.
|
||||
|
|
|
@ -39,6 +39,14 @@ func NewOAuth2ConsentSession(subject uuid.UUID, r fosite.Requester) (consent *OA
|
|||
return consent, nil
|
||||
}
|
||||
|
||||
// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI.
|
||||
func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) {
|
||||
return OAuth2BlacklistedJTI{
|
||||
Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))),
|
||||
ExpiresAt: exp,
|
||||
}
|
||||
}
|
||||
|
||||
// NewOAuth2SessionFromRequest creates a new OAuth2Session from a signature and fosite.Requester.
|
||||
func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session *OAuth2Session, err error) {
|
||||
var (
|
||||
|
@ -77,12 +85,43 @@ func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session
|
|||
}, nil
|
||||
}
|
||||
|
||||
// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI.
|
||||
func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) {
|
||||
return OAuth2BlacklistedJTI{
|
||||
Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))),
|
||||
ExpiresAt: exp,
|
||||
// NewOAuth2PARContext creates a new Pushed Authorization Request Context as a OAuth2PARContext.
|
||||
func NewOAuth2PARContext(contextID string, r fosite.AuthorizeRequester) (context *OAuth2PARContext, err error) {
|
||||
var (
|
||||
s *OpenIDSession
|
||||
ok bool
|
||||
req *fosite.AuthorizeRequest
|
||||
session []byte
|
||||
)
|
||||
|
||||
if s, ok = r.GetSession().(*OpenIDSession); !ok {
|
||||
return nil, fmt.Errorf("can't convert type '%T' to an *OAuth2Session", r.GetSession())
|
||||
}
|
||||
|
||||
if session, err = json.Marshal(s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var handled StringSlicePipeDelimited
|
||||
|
||||
if req, ok = r.(*fosite.AuthorizeRequest); ok {
|
||||
handled = StringSlicePipeDelimited(req.HandledResponseTypes)
|
||||
}
|
||||
|
||||
return &OAuth2PARContext{
|
||||
Signature: contextID,
|
||||
RequestID: r.GetID(),
|
||||
ClientID: r.GetClient().GetID(),
|
||||
RequestedAt: r.GetRequestedAt(),
|
||||
Scopes: StringSlicePipeDelimited(r.GetRequestedScopes()),
|
||||
Audience: StringSlicePipeDelimited(r.GetRequestedAudience()),
|
||||
HandledResponseTypes: handled,
|
||||
ResponseMode: string(r.GetResponseMode()),
|
||||
DefaultResponseMode: string(r.GetDefaultResponseMode()),
|
||||
Revoked: false,
|
||||
Form: r.GetRequestForm().Encode(),
|
||||
Session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OAuth2ConsentPreConfig stores information about an OAuth2.0 Pre-Configured Consent.
|
||||
|
@ -264,6 +303,70 @@ func (s *OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, s
|
|||
}, nil
|
||||
}
|
||||
|
||||
// OAuth2PARContext holds relevant information about a Pushed Authorization Request in order to process the authorization.
|
||||
type OAuth2PARContext struct {
|
||||
ID int `db:"id"`
|
||||
Signature string `db:"signature"`
|
||||
RequestID string `db:"request_id"`
|
||||
ClientID string `db:"client_id"`
|
||||
RequestedAt time.Time `db:"requested_at"`
|
||||
Scopes StringSlicePipeDelimited `db:"scopes"`
|
||||
Audience StringSlicePipeDelimited `db:"audience"`
|
||||
HandledResponseTypes StringSlicePipeDelimited `db:"handled_response_types"`
|
||||
ResponseMode string `db:"response_mode"`
|
||||
DefaultResponseMode string `db:"response_mode_default"`
|
||||
Revoked bool `db:"revoked"`
|
||||
Form string `db:"form_data"`
|
||||
Session []byte `db:"session_data"`
|
||||
}
|
||||
|
||||
func (par *OAuth2PARContext) ToAuthorizeRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.AuthorizeRequest, err error) {
|
||||
if session != nil {
|
||||
if err = json.Unmarshal(par.Session, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
client fosite.Client
|
||||
form url.Values
|
||||
)
|
||||
|
||||
if client, err = store.GetClient(ctx, par.ClientID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if form, err = url.ParseQuery(par.Form); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request = fosite.NewAuthorizeRequest()
|
||||
|
||||
request.Request = fosite.Request{
|
||||
ID: par.RequestID,
|
||||
RequestedAt: par.RequestedAt,
|
||||
Client: client,
|
||||
RequestedScope: fosite.Arguments(par.Scopes),
|
||||
RequestedAudience: fosite.Arguments(par.Audience),
|
||||
Form: form,
|
||||
Session: session,
|
||||
}
|
||||
|
||||
if par.ResponseMode != "" {
|
||||
request.ResponseMode = fosite.ResponseModeType(par.ResponseMode)
|
||||
}
|
||||
|
||||
if par.DefaultResponseMode != "" {
|
||||
request.DefaultResponseMode = fosite.ResponseModeType(par.DefaultResponseMode)
|
||||
}
|
||||
|
||||
if len(par.HandledResponseTypes) != 0 {
|
||||
request.HandledResponseTypes = fosite.Arguments(par.HandledResponseTypes)
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// OpenIDSession holds OIDC Session information.
|
||||
type OpenIDSession struct {
|
||||
*openid.DefaultSession `json:"id_token"`
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/x/errorsx"
|
||||
|
@ -32,6 +32,8 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client)
|
|||
ResponseTypes: config.ResponseTypes,
|
||||
ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeDefault},
|
||||
|
||||
EnforcePAR: config.EnforcePAR,
|
||||
|
||||
UserinfoSigningAlgorithm: config.UserinfoSigningAlgorithm,
|
||||
|
||||
Policy: authorization.NewLevel(config.Policy),
|
||||
|
@ -46,22 +48,22 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client)
|
|||
return client
|
||||
}
|
||||
|
||||
// ValidateAuthorizationPolicy is a helper function to validate additional policy constraints on a per-client basis.
|
||||
func (c *Client) ValidateAuthorizationPolicy(r fosite.Requester) (err error) {
|
||||
// ValidatePKCEPolicy is a helper function to validate PKCE policy constraints on a per-client basis.
|
||||
func (c *Client) ValidatePKCEPolicy(r fosite.Requester) (err error) {
|
||||
form := r.GetRequestForm()
|
||||
|
||||
if c.EnforcePKCE {
|
||||
if form.Get("code_challenge") == "" {
|
||||
if form.Get(FormParameterCodeChallenge) == "" {
|
||||
return errorsx.WithStack(fosite.ErrInvalidRequest.
|
||||
WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing.").
|
||||
WithDebug("The server is configured in a way that enforces PKCE for this client."))
|
||||
}
|
||||
|
||||
if c.EnforcePKCEChallengeMethod {
|
||||
if method := form.Get("code_challenge_method"); method != c.PKCEChallengeMethod {
|
||||
if method := form.Get(FormParameterCodeChallengeMethod); method != c.PKCEChallengeMethod {
|
||||
return errorsx.WithStack(fosite.ErrInvalidRequest.
|
||||
WithHint(fmt.Sprintf("Client must use code_challenge_method=%s, %s is not allowed.", c.PKCEChallengeMethod, method)).
|
||||
WithDebug(fmt.Sprintf("The server is configured in a way that enforces PKCE %s as challenge method for this client.", c.PKCEChallengeMethod)))
|
||||
WithHintf("Client must use code_challenge_method=%s, %s is not allowed.", c.PKCEChallengeMethod, method).
|
||||
WithDebugf("The server is configured in a way that enforces PKCE %s as challenge method for this client.", c.PKCEChallengeMethod))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -69,6 +71,23 @@ func (c *Client) ValidateAuthorizationPolicy(r fosite.Requester) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ValidatePARPolicy is a helper function to validate additional policy constraints on a per-client basis.
|
||||
func (c *Client) ValidatePARPolicy(r fosite.Requester, prefix string) (err error) {
|
||||
form := r.GetRequestForm()
|
||||
|
||||
if c.EnforcePAR {
|
||||
if requestURI := form.Get(FormParameterRequestURI); !strings.HasPrefix(requestURI, prefix) {
|
||||
if requestURI == "" {
|
||||
return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebug("The request_uri parameter was empty."))
|
||||
}
|
||||
|
||||
return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebugf("The request_uri parameter '%s' is malformed.", requestURI))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient.
|
||||
func (c *Client) IsAuthenticationLevelSufficient(level authentication.Level) bool {
|
||||
if level == authentication.NotAuthenticated {
|
||||
|
@ -105,7 +124,7 @@ func (c *Client) GetID() string {
|
|||
}
|
||||
|
||||
// GetHashedSecret returns the Secret.
|
||||
func (c *Client) GetHashedSecret() []byte {
|
||||
func (c *Client) GetHashedSecret() (secret []byte) {
|
||||
if c.Secret == nil {
|
||||
return []byte(nil)
|
||||
}
|
||||
|
@ -114,7 +133,7 @@ func (c *Client) GetHashedSecret() []byte {
|
|||
}
|
||||
|
||||
// GetRedirectURIs returns the RedirectURIs.
|
||||
func (c *Client) GetRedirectURIs() []string {
|
||||
func (c *Client) GetRedirectURIs() (redirectURIs []string) {
|
||||
return c.RedirectURIs
|
||||
}
|
||||
|
||||
|
|
|
@ -224,7 +224,7 @@ func TestNewClientPKCE(t *testing.T) {
|
|||
expectedEnforcePKCE bool
|
||||
expectedEnforcePKCEChallengeMethod bool
|
||||
expected string
|
||||
req *fosite.Request
|
||||
r *fosite.Request
|
||||
err string
|
||||
}{
|
||||
{
|
||||
|
@ -288,8 +288,8 @@ func TestNewClientPKCE(t *testing.T) {
|
|||
assert.Equal(t, tc.expectedEnforcePKCEChallengeMethod, client.EnforcePKCEChallengeMethod)
|
||||
assert.Equal(t, tc.expected, client.PKCEChallengeMethod)
|
||||
|
||||
if tc.req != nil {
|
||||
err := client.ValidateAuthorizationPolicy(tc.req)
|
||||
if tc.r != nil {
|
||||
err := client.ValidatePKCEPolicy(tc.r)
|
||||
|
||||
if tc.err != "" {
|
||||
assert.EqualError(t, err, tc.err)
|
||||
|
|
|
@ -24,8 +24,8 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
func NewConfig(config *schema.OpenIDConnectConfiguration, templates *templates.Provider) *Config {
|
||||
c := &Config{
|
||||
func NewConfig(config *schema.OpenIDConnectConfiguration, templates *templates.Provider) (c *Config) {
|
||||
c = &Config{
|
||||
GlobalSecret: []byte(utils.HashSHA256FromString(config.HMACSecret)),
|
||||
SendDebugMessagesToClients: config.EnableClientDebugMessages,
|
||||
MinParameterEntropy: config.MinimumParameterEntropy,
|
||||
|
@ -40,18 +40,23 @@ func NewConfig(config *schema.OpenIDConnectConfiguration, templates *templates.P
|
|||
EnforcePublicClients: config.EnforcePKCE != "never",
|
||||
AllowPlainChallengeMethod: config.EnablePKCEPlainChallenge,
|
||||
},
|
||||
PAR: PARConfig{
|
||||
Enforced: config.PAR.Enforce,
|
||||
ContextLifespan: config.PAR.ContextLifespan,
|
||||
URIPrefix: urnPARPrefix,
|
||||
},
|
||||
Templates: templates,
|
||||
}
|
||||
|
||||
c.Strategy.Core = &HMACCoreStrategy{
|
||||
Enigma: &hmac.HMACStrategy{Config: c},
|
||||
Config: c,
|
||||
prefix: tokenPrefixFmt,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Config is an implementation of the fosite.Configurator.
|
||||
type Config struct {
|
||||
// GlobalSecret is the global secret used to sign and verify signatures.
|
||||
GlobalSecret []byte
|
||||
|
@ -68,7 +73,7 @@ type Config struct {
|
|||
JWTScopeField jwt.JWTScopeFieldEnum
|
||||
JWTMaxDuration time.Duration
|
||||
|
||||
Hasher *AdaptiveHasher
|
||||
Hasher *Hasher
|
||||
Hash HashConfig
|
||||
Strategy StrategyConfig
|
||||
PAR PARConfig
|
||||
|
@ -92,11 +97,13 @@ type Config struct {
|
|||
Templates *templates.Provider
|
||||
}
|
||||
|
||||
// HashConfig holds specific fosite.Configurator information for hashing.
|
||||
type HashConfig struct {
|
||||
ClientSecrets fosite.Hasher
|
||||
HMAC func() (h hash.Hash)
|
||||
}
|
||||
|
||||
// StrategyConfig holds specific fosite.Configurator information for various strategies.
|
||||
type StrategyConfig struct {
|
||||
Core oauth2.CoreStrategy
|
||||
OpenID openid.OpenIDConnectTokenStrategy
|
||||
|
@ -106,17 +113,20 @@ type StrategyConfig struct {
|
|||
ClientAuthentication fosite.ClientAuthenticationStrategy
|
||||
}
|
||||
|
||||
// PARConfig holds specific fosite.Configurator information for Pushed Authorization Requests.
|
||||
type PARConfig struct {
|
||||
Enforced bool
|
||||
URIPrefix string
|
||||
ContextLifespan time.Duration
|
||||
}
|
||||
|
||||
// IssuersConfig holds specific fosite.Configurator information for the issuer.
|
||||
type IssuersConfig struct {
|
||||
IDToken string
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
// HandlersConfig holds specific fosite.Configurator handlers configuration information.
|
||||
type HandlersConfig struct {
|
||||
// ResponseMode provides an extension handler for custom response modes.
|
||||
ResponseMode fosite.ResponseModeHandler
|
||||
|
@ -137,18 +147,21 @@ type HandlersConfig struct {
|
|||
PushedAuthorizeEndpoint fosite.PushedAuthorizeEndpointHandlers
|
||||
}
|
||||
|
||||
// GrantTypeJWTBearerConfig holds specific fosite.Configurator information for the JWT Bearer Grant Type.
|
||||
type GrantTypeJWTBearerConfig struct {
|
||||
OptionalClientAuth bool
|
||||
OptionalJTIClaim bool
|
||||
OptionalIssuedDate bool
|
||||
}
|
||||
|
||||
// ProofKeyCodeExchangeConfig holds specific fosite.Configurator information for PKCE.
|
||||
type ProofKeyCodeExchangeConfig struct {
|
||||
Enforce bool
|
||||
EnforcePublicClients bool
|
||||
AllowPlainChallengeMethod bool
|
||||
}
|
||||
|
||||
// LifespanConfig holds specific fosite.Configurator information for various lifespans.
|
||||
type LifespanConfig struct {
|
||||
AccessToken time.Duration
|
||||
AuthorizeCode time.Duration
|
||||
|
@ -162,6 +175,7 @@ const (
|
|||
PromptConsent = "consent"
|
||||
)
|
||||
|
||||
// LoadHandlers reloads the handlers based on the current configuration.
|
||||
func (c *Config) LoadHandlers(store *Store, strategy jwt.Signer) {
|
||||
validator := openid.NewOpenIDConnectRequestValidator(strategy, c)
|
||||
|
||||
|
@ -278,6 +292,10 @@ func (c *Config) LoadHandlers(store *Store, strategy jwt.Signer) {
|
|||
if h, ok := handler.(fosite.RevocationHandler); ok {
|
||||
x.Revocation.Append(h)
|
||||
}
|
||||
|
||||
if h, ok := handler.(fosite.PushedAuthorizeEndpointHandler); ok {
|
||||
x.PushedAuthorizeEndpoint.Append(h)
|
||||
}
|
||||
}
|
||||
|
||||
c.Handlers = x
|
||||
|
@ -533,7 +551,7 @@ func (c *Config) GetTokenURL(ctx context.Context) (tokenURL string) {
|
|||
// GetSecretsHasher returns the client secrets hashing function.
|
||||
func (c *Config) GetSecretsHasher(ctx context.Context) (hasher fosite.Hasher) {
|
||||
if c.Hash.ClientSecrets == nil {
|
||||
c.Hash.ClientSecrets, _ = NewAdaptiveHasher()
|
||||
c.Hash.ClientSecrets, _ = NewHasher()
|
||||
}
|
||||
|
||||
return c.Hash.ClientSecrets
|
||||
|
@ -595,7 +613,7 @@ func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool {
|
|||
|
||||
// GetPushedAuthorizeContextLifespan is the lifespan of the short-lived PAR context.
|
||||
func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) (lifespan time.Duration) {
|
||||
if c.PAR.ContextLifespan == 0 {
|
||||
if c.PAR.ContextLifespan.Seconds() == 0 {
|
||||
c.PAR.ContextLifespan = lifespanPARContextDefault
|
||||
}
|
||||
|
||||
|
|
|
@ -110,6 +110,12 @@ const (
|
|||
PKCEChallengeMethodSHA256 = "S256"
|
||||
)
|
||||
|
||||
const (
|
||||
FormParameterRequestURI = "request_uri"
|
||||
FormParameterCodeChallenge = "code_challenge"
|
||||
FormParameterCodeChallengeMethod = "code_challenge_method"
|
||||
)
|
||||
|
||||
// Endpoints.
|
||||
const (
|
||||
EndpointAuthorization = "authorization"
|
||||
|
@ -117,6 +123,7 @@ const (
|
|||
EndpointUserinfo = "userinfo"
|
||||
EndpointIntrospection = "introspection"
|
||||
EndpointRevocation = "revocation"
|
||||
EndpointPushedAuthorizationRequest = "pushed-authorization-request"
|
||||
)
|
||||
|
||||
// JWT Headers.
|
||||
|
@ -126,7 +133,9 @@ const (
|
|||
)
|
||||
|
||||
const (
|
||||
tokenPrefixFmt = "authelia_%s_" //nolint:gosec
|
||||
tokenPrefixOrgAutheliaFmt = "authelia_%s_" //nolint:gosec
|
||||
tokenPrefixOrgOryFmt = "ory_%s_" //nolint:gosec
|
||||
|
||||
tokenPrefixPartAccessToken = "at"
|
||||
tokenPrefixPartRefreshToken = "rt"
|
||||
tokenPrefixPartAuthorizeCode = "ac"
|
||||
|
@ -146,6 +155,8 @@ const (
|
|||
EndpointPathUserinfo = EndpointPathRoot + "/" + EndpointUserinfo
|
||||
EndpointPathIntrospection = EndpointPathRoot + "/" + EndpointIntrospection
|
||||
EndpointPathRevocation = EndpointPathRoot + "/" + EndpointRevocation
|
||||
|
||||
EndpointPathPushedAuthorizationRequest = EndpointPathRoot + "/" + EndpointPushedAuthorizationRequest
|
||||
)
|
||||
|
||||
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176
|
||||
|
|
|
@ -19,7 +19,6 @@ type HMACCoreStrategy struct {
|
|||
fosite.RefreshTokenLifespanProvider
|
||||
fosite.AuthorizeCodeLifespanProvider
|
||||
}
|
||||
prefix string
|
||||
}
|
||||
|
||||
// AccessTokenSignature implements oauth2.AccessTokenStrategy.
|
||||
|
@ -112,11 +111,11 @@ func (h *HMACCoreStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.R
|
|||
}
|
||||
|
||||
func (h *HMACCoreStrategy) getPrefix(part string) string {
|
||||
if len(h.prefix) == 0 {
|
||||
return ""
|
||||
}
|
||||
return h.getCustomPrefix(tokenPrefixOrgAutheliaFmt, part)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(h.prefix, part)
|
||||
func (h *HMACCoreStrategy) getCustomPrefix(tokenPrefixFmt, part string) string {
|
||||
return fmt.Sprintf(tokenPrefixFmt, part)
|
||||
}
|
||||
|
||||
func (h *HMACCoreStrategy) setPrefix(token, part string) string {
|
||||
|
@ -124,5 +123,9 @@ func (h *HMACCoreStrategy) setPrefix(token, part string) string {
|
|||
}
|
||||
|
||||
func (h *HMACCoreStrategy) trimPrefix(token, part string) string {
|
||||
if strings.HasPrefix(token, h.getCustomPrefix(tokenPrefixOrgOryFmt, part)) {
|
||||
return strings.TrimPrefix(token, h.getCustomPrefix(tokenPrefixOrgOryFmt, part))
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(token, h.getPrefix(part))
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHMACCoreStrategy_TrimPrefix(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
have string
|
||||
part string
|
||||
expected string
|
||||
}{
|
||||
{"ShouldTrimAutheliaPrefix", "authelia_at_example", tokenPrefixPartAccessToken, "example"},
|
||||
{"ShouldTrimOryPrefix", "ory_at_example", tokenPrefixPartAccessToken, "example"},
|
||||
{"ShouldTrimOnlyAutheliaPrefix", "authelia_at_ory_at_example", tokenPrefixPartAccessToken, "ory_at_example"},
|
||||
{"ShouldTrimOnlyOryPrefix", "ory_at_authelia_at_example", tokenPrefixPartAccessToken, "authelia_at_example"},
|
||||
{"ShouldNotTrimGitHubPrefix", "gh_at_example", tokenPrefixPartAccessToken, "gh_at_example"},
|
||||
}
|
||||
|
||||
strategy := &HMACCoreStrategy{}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, strategy.trimPrefix(tc.have, tc.part))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHMACCoreStrategy_GetSetPrefix(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
have string
|
||||
expectedSet string
|
||||
expectedGet string
|
||||
}{
|
||||
{"ShouldAddPrefix", "example", "authelia_%s_example", "authelia_%s_"},
|
||||
}
|
||||
|
||||
strategy := &HMACCoreStrategy{}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, part := range []string{tokenPrefixPartAccessToken, tokenPrefixPartAuthorizeCode, tokenPrefixPartRefreshToken} {
|
||||
t.Run(strings.ToUpper(part), func(t *testing.T) {
|
||||
assert.Equal(t, fmt.Sprintf(tc.expectedSet, part), strategy.setPrefix(tc.have, part))
|
||||
assert.Equal(t, fmt.Sprintf(tc.expectedGet, part), strategy.getPrefix(part))
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,7 +1,11 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// NewOpenIDConnectWellKnownConfiguration generates a new OpenIDConnectWellKnownConfiguration.
|
||||
func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge bool, clients map[string]*Client) (config OpenIDConnectWellKnownConfiguration) {
|
||||
func NewOpenIDConnectWellKnownConfiguration(c *schema.OpenIDConnectConfiguration, clients map[string]*Client) (config OpenIDConnectWellKnownConfiguration) {
|
||||
config = OpenIDConnectWellKnownConfiguration{
|
||||
CommonDiscoveryOptions: CommonDiscoveryOptions{
|
||||
SubjectTypesSupported: []string{
|
||||
|
@ -78,6 +82,9 @@ func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge bool, clien
|
|||
SigningAlgorithmRSAWithSHA256,
|
||||
},
|
||||
},
|
||||
PushedAuthorizationDiscoveryOptions: PushedAuthorizationDiscoveryOptions{
|
||||
RequirePushedAuthorizationRequests: c.PAR.Enforce,
|
||||
},
|
||||
}
|
||||
|
||||
var pairwise, public bool
|
||||
|
@ -96,7 +103,7 @@ func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge bool, clien
|
|||
config.SubjectTypesSupported = append(config.SubjectTypesSupported, SubjectTypePairwise)
|
||||
}
|
||||
|
||||
if enablePKCEPlainChallenge {
|
||||
if c.EnablePKCEPlainChallenge {
|
||||
config.CodeChallengeMethodsSupported = append(config.CodeChallengeMethodsSupported, PKCEChallengeMethodPlain)
|
||||
}
|
||||
|
||||
|
|
|
@ -4,12 +4,15 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
pkcePlainChallenge bool
|
||||
enforcePAR bool
|
||||
clients map[string]*Client
|
||||
|
||||
expectCodeChallengeMethodsSupported, expectSubjectTypesSupported []string
|
||||
|
@ -63,7 +66,14 @@ func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
actual := NewOpenIDConnectWellKnownConfiguration(tc.pkcePlainChallenge, tc.clients)
|
||||
c := schema.OpenIDConnectConfiguration{
|
||||
EnablePKCEPlainChallenge: tc.pkcePlainChallenge,
|
||||
PAR: schema.OpenIDConnectPARConfiguration{
|
||||
Enforce: tc.enforcePAR,
|
||||
},
|
||||
}
|
||||
|
||||
actual := NewOpenIDConnectWellKnownConfiguration(&c, tc.clients)
|
||||
for _, codeChallengeMethod := range tc.expectCodeChallengeMethodsSupported {
|
||||
assert.Contains(t, actual.CodeChallengeMethodsSupported, codeChallengeMethod)
|
||||
}
|
||||
|
|
|
@ -9,11 +9,27 @@ import (
|
|||
var errPasswordsDoNotMatch = errors.New("the passwords don't match")
|
||||
|
||||
var (
|
||||
// ErrIssuerCouldNotDerive is sent when the issuer couldn't be determined from the headers.
|
||||
ErrIssuerCouldNotDerive = fosite.ErrServerError.WithHint("Could not safely derive the issuer.")
|
||||
|
||||
// ErrSubjectCouldNotLookup is sent when the Subject Identifier for a user couldn't be generated or obtained from the database.
|
||||
ErrSubjectCouldNotLookup = fosite.ErrServerError.WithHint("Could not lookup user subject.")
|
||||
|
||||
// ErrConsentCouldNotPerform is sent when the Consent Session couldn't be performed for varying reasons.
|
||||
ErrConsentCouldNotPerform = fosite.ErrServerError.WithHint("Could not perform consent.")
|
||||
|
||||
// ErrConsentCouldNotGenerate is sent when the Consent Session failed to be generated for some reason, usually a failed UUIDv4 generation.
|
||||
ErrConsentCouldNotGenerate = fosite.ErrServerError.WithHint("Could not generate the consent session.")
|
||||
|
||||
// ErrConsentCouldNotSave is sent when the Consent Session couldn't be saved to the database.
|
||||
ErrConsentCouldNotSave = fosite.ErrServerError.WithHint("Could not save the consent session.")
|
||||
|
||||
// ErrConsentCouldNotLookup is sent when the Consent ID is not a known UUID.
|
||||
ErrConsentCouldNotLookup = fosite.ErrServerError.WithHint("Failed to lookup the consent session.")
|
||||
|
||||
// ErrConsentMalformedChallengeID is sent when the Consent ID is not a UUID.
|
||||
ErrConsentMalformedChallengeID = fosite.ErrServerError.WithHint("Malformed consent session challenge ID.")
|
||||
|
||||
// ErrPAREnforcedClientMissingPAR is sent when a client has EnforcePAR configured but the Authorization Request was not Pushed.
|
||||
ErrPAREnforcedClientMissingPAR = fosite.ErrInvalidRequest.WithHint("Pushed Authorization Requests are enforced for this client but no such request was sent.")
|
||||
)
|
||||
|
|
|
@ -8,8 +8,9 @@ import (
|
|||
"github.com/go-crypt/crypt/algorithm/plaintext"
|
||||
)
|
||||
|
||||
func NewAdaptiveHasher() (hasher *AdaptiveHasher, err error) {
|
||||
hasher = &AdaptiveHasher{}
|
||||
// NewHasher returns a new Hasher.
|
||||
func NewHasher() (hasher *Hasher, err error) {
|
||||
hasher = &Hasher{}
|
||||
|
||||
if hasher.decoder, err = crypt.NewDefaultDecoder(); err != nil {
|
||||
return nil, err
|
||||
|
@ -22,13 +23,13 @@ func NewAdaptiveHasher() (hasher *AdaptiveHasher, err error) {
|
|||
return hasher, nil
|
||||
}
|
||||
|
||||
// AdaptiveHasher implements the fosite.Hasher interface without an actual hashing algo.
|
||||
type AdaptiveHasher struct {
|
||||
// Hasher implements the fosite.Hasher interface and adaptively compares hashes.
|
||||
type Hasher struct {
|
||||
decoder algorithm.DecoderRegister
|
||||
}
|
||||
|
||||
// Compare compares the hash with the data and returns an error if they don't match.
|
||||
func (h *AdaptiveHasher) Compare(_ context.Context, hash, data []byte) (err error) {
|
||||
func (h Hasher) Compare(_ context.Context, hash, data []byte) (err error) {
|
||||
var digest algorithm.Digest
|
||||
|
||||
if digest, err = h.decoder.Decode(string(hash)); err != nil {
|
||||
|
@ -43,6 +44,6 @@ func (h *AdaptiveHasher) Compare(_ context.Context, hash, data []byte) (err erro
|
|||
}
|
||||
|
||||
// Hash creates a new hash from data.
|
||||
func (h *AdaptiveHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) {
|
||||
func (h Hasher) Hash(_ context.Context, data []byte) (hash []byte, err error) {
|
||||
return data, nil
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
|
||||
hasher, err := NewAdaptiveHasher()
|
||||
hasher, err := NewHasher()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -22,7 +22,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldNotRaiseErrorOnEqualPasswordsPlainTextWithSeparator(t *testing.T) {
|
||||
hasher, err := NewAdaptiveHasher()
|
||||
hasher, err := NewHasher()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -35,7 +35,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainTextWithSeparator(t *testing.T)
|
|||
}
|
||||
|
||||
func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
|
||||
hasher, err := NewAdaptiveHasher()
|
||||
hasher, err := NewHasher()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -48,7 +48,7 @@ func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldHashPassword(t *testing.T) {
|
||||
hasher := AdaptiveHasher{}
|
||||
hasher := Hasher{}
|
||||
|
||||
data := []byte("abc")
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ func NewOpenIDConnectProvider(config *schema.OpenIDConnectConfiguration, store s
|
|||
|
||||
provider.Config.LoadHandlers(provider.Store, provider.KeyManager.Strategy())
|
||||
|
||||
provider.discovery = NewOpenIDConnectWellKnownConfiguration(config.EnablePKCEPlainChallenge, provider.Store.clients)
|
||||
provider.discovery = NewOpenIDConnectWellKnownConfiguration(config, provider.Store.clients)
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
@ -50,12 +50,12 @@ func (p *OpenIDConnectProvider) GetOAuth2WellKnownConfiguration(issuer string) O
|
|||
}
|
||||
|
||||
options.Issuer = issuer
|
||||
|
||||
options.JWKSURI = fmt.Sprintf("%s%s", issuer, EndpointPathJWKs)
|
||||
|
||||
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
|
||||
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
|
||||
|
||||
options.AuthorizationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathAuthorization)
|
||||
options.PushedAuthorizationRequestEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathPushedAuthorizationRequest)
|
||||
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
|
||||
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
|
||||
options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation)
|
||||
|
||||
return options
|
||||
|
@ -72,14 +72,14 @@ func (p *OpenIDConnectProvider) GetOpenIDConnectWellKnownConfiguration(issuer st
|
|||
}
|
||||
|
||||
options.Issuer = issuer
|
||||
|
||||
options.JWKSURI = fmt.Sprintf("%s%s", issuer, EndpointPathJWKs)
|
||||
|
||||
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
|
||||
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
|
||||
|
||||
options.AuthorizationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathAuthorization)
|
||||
options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation)
|
||||
options.PushedAuthorizationRequestEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathPushedAuthorizationRequest)
|
||||
options.TokenEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathToken)
|
||||
options.UserinfoEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathUserinfo)
|
||||
options.IntrospectionEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathIntrospection)
|
||||
options.RevocationEndpoint = fmt.Sprintf("%s%s", issuer, EndpointPathRevocation)
|
||||
|
||||
return options
|
||||
}
|
||||
|
|
|
@ -165,7 +165,7 @@ func (s *Store) InvalidateAuthorizeCodeSession(ctx context.Context, code string)
|
|||
// This implements a portion of oauth2.AuthorizeCodeStorage.
|
||||
func (s *Store) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) {
|
||||
// TODO: Implement the fosite.ErrInvalidatedAuthorizeCode error above. This requires splitting the invalidated sessions and deleted sessions.
|
||||
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session)
|
||||
return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session)
|
||||
}
|
||||
|
||||
// CreateAccessTokenSession stores the authorization request for a given access token.
|
||||
|
@ -190,7 +190,7 @@ func (s *Store) RevokeAccessToken(ctx context.Context, requestID string) (err er
|
|||
// GetAccessTokenSession gets the authorization request for a given access token.
|
||||
// This implements a portion of oauth2.AccessTokenStorage.
|
||||
func (s *Store) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
|
||||
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session)
|
||||
return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session)
|
||||
}
|
||||
|
||||
// CreateRefreshTokenSession stores the authorization request for a given refresh token.
|
||||
|
@ -223,7 +223,7 @@ func (s *Store) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestI
|
|||
// GetRefreshTokenSession gets the authorization request for a given refresh token.
|
||||
// This implements a portion of oauth2.RefreshTokenStorage.
|
||||
func (s *Store) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
|
||||
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session)
|
||||
return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session)
|
||||
}
|
||||
|
||||
// CreatePKCERequestSession stores the authorization request for a given PKCE request.
|
||||
|
@ -241,7 +241,7 @@ func (s *Store) DeletePKCERequestSession(ctx context.Context, signature string)
|
|||
// GetPKCERequestSession gets the authorization request for a given PKCE request.
|
||||
// This implements a portion of pkce.PKCERequestStorage.
|
||||
func (s *Store) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (requester fosite.Requester, err error) {
|
||||
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session)
|
||||
return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session)
|
||||
}
|
||||
|
||||
// CreateOpenIDConnectSession creates an open id connect session for a given authorize code.
|
||||
|
@ -263,7 +263,37 @@ func (s *Store) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode st
|
|||
// - or an arbitrary error if an error occurred.
|
||||
// This implements a portion of openid.OpenIDConnectRequestStorage.
|
||||
func (s *Store) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (r fosite.Requester, err error) {
|
||||
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession())
|
||||
return s.loadRequesterBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession())
|
||||
}
|
||||
|
||||
// CreatePARSession stores the pushed authorization request context. The requestURI is used to derive the key.
|
||||
// This implements a portion of fosite.PARStorage.
|
||||
func (s *Store) CreatePARSession(ctx context.Context, requestURI string, request fosite.AuthorizeRequester) (err error) {
|
||||
var par *model.OAuth2PARContext
|
||||
|
||||
if par, err = model.NewOAuth2PARContext(requestURI, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.provider.SaveOAuth2PARContext(ctx, *par)
|
||||
}
|
||||
|
||||
// GetPARSession gets the push authorization request context. The caller is expected to merge the AuthorizeRequest.
|
||||
// This implements a portion of fosite.PARStorage.
|
||||
func (s *Store) GetPARSession(ctx context.Context, requestURI string) (request fosite.AuthorizeRequester, err error) {
|
||||
var par *model.OAuth2PARContext
|
||||
|
||||
if par, err = s.provider.LoadOAuth2PARContext(ctx, requestURI); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return par.ToAuthorizeRequest(ctx, NewSession(), s)
|
||||
}
|
||||
|
||||
// DeletePARSession deletes the context.
|
||||
// This implements a portion of fosite.PARStorage.
|
||||
func (s *Store) DeletePARSession(ctx context.Context, requestURI string) (err error) {
|
||||
return s.provider.RevokeOAuth2PARContext(ctx, requestURI)
|
||||
}
|
||||
|
||||
// IsJWTUsed implements an interface required for RFC7523.
|
||||
|
@ -280,7 +310,7 @@ func (s *Store) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Tim
|
|||
return s.SetClientAssertionJWT(ctx, jti, exp)
|
||||
}
|
||||
|
||||
func (s *Store) loadSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) {
|
||||
func (s *Store) loadRequesterBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) {
|
||||
var (
|
||||
sessionModel *model.OAuth2Session
|
||||
)
|
||||
|
|
|
@ -119,6 +119,8 @@ type Client struct {
|
|||
ResponseTypes []string
|
||||
ResponseModes []fosite.ResponseModeType
|
||||
|
||||
EnforcePAR bool
|
||||
|
||||
UserinfoSigningAlgorithm string
|
||||
|
||||
Policy authorization.Level
|
||||
|
|
|
@ -331,6 +331,15 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
|
|||
r.GET("/api/oidc/authorize", policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization))))
|
||||
r.POST("/api/oidc/authorize", policyCORSAuthorization.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorization))))
|
||||
|
||||
policyCORSPAR := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
|
||||
WithAllowedOrigins(allowedOrigins...).
|
||||
WithEnabled(utils.IsStringInSliceFold(oidc.EndpointPushedAuthorizationRequest, config.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||
Build()
|
||||
|
||||
r.OPTIONS(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.HandleOnlyOPTIONS)
|
||||
r.POST(oidc.EndpointPathPushedAuthorizationRequest, policyCORSPAR.Middleware(bridgeOIDC(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectPushedAuthorizationRequest))))
|
||||
|
||||
policyCORSToken := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowCredentials(true).
|
||||
WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodPost).
|
||||
|
|
|
@ -13,15 +13,16 @@ const (
|
|||
tableUserPreferences = "user_preferences"
|
||||
tableWebauthnDevices = "webauthn_devices"
|
||||
|
||||
tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
|
||||
tableOAuth2ConsentSession = "oauth2_consent_session"
|
||||
tableOAuth2ConsentPreConfiguration = "oauth2_consent_preconfiguration"
|
||||
|
||||
tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
|
||||
tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential.
|
||||
tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
|
||||
tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
|
||||
tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
|
||||
tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session"
|
||||
tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
|
||||
tableOAuth2PARContext = "oauth2_par_context"
|
||||
tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
|
||||
tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
|
||||
|
||||
tableMigrations = "migrations"
|
||||
tableEncryption = "encryption"
|
||||
|
@ -32,26 +33,29 @@ type OAuth2SessionType int
|
|||
|
||||
// Representation of specific OAuth 2.0 session types.
|
||||
const (
|
||||
OAuth2SessionTypeAuthorizeCode OAuth2SessionType = iota
|
||||
OAuth2SessionTypeAccessToken
|
||||
OAuth2SessionTypeRefreshToken
|
||||
OAuth2SessionTypePKCEChallenge
|
||||
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
|
||||
OAuth2SessionTypeAuthorizeCode
|
||||
OAuth2SessionTypeOpenIDConnect
|
||||
OAuth2SessionTypePAR
|
||||
OAuth2SessionTypePKCEChallenge
|
||||
OAuth2SessionTypeRefreshToken
|
||||
)
|
||||
|
||||
// String returns a string representation of this OAuth2SessionType.
|
||||
func (s OAuth2SessionType) String() string {
|
||||
switch s {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return "authorization code"
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
return "access token"
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return "refresh token"
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return "pkce challenge"
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return "authorization code"
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
return "openid connect"
|
||||
case OAuth2SessionTypePAR:
|
||||
return "pushed authorization request context"
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return "pkce challenge"
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return "refresh token"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
|
@ -60,16 +64,18 @@ func (s OAuth2SessionType) String() string {
|
|||
// Table returns the table name for this session type.
|
||||
func (s OAuth2SessionType) Table() string {
|
||||
switch s {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return tableOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
return tableOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return tableOAuth2RefreshTokenSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return tableOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return tableOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
return tableOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePAR:
|
||||
return tableOAuth2PARContext
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return tableOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return tableOAuth2RefreshTokenSession
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
@ -119,7 +125,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`)
|
||||
reMigration = regexp.MustCompile(`^V(?P<Version>\d{4})\.(?P<Name>[^.]+)\.(?P<Provider>(all|sqlite|postgres|mysql))\.(?P<Direction>(up|down))\.sql$`)
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -130,15 +130,15 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
|
|||
}
|
||||
|
||||
func scanMigration(m string) (migration model.SchemaMigration, err error) {
|
||||
result := reMigration.FindStringSubmatch(m)
|
||||
|
||||
if result == nil || len(result) != 5 {
|
||||
if !reMigration.MatchString(m) {
|
||||
return model.SchemaMigration{}, errors.New("invalid migration: could not parse the format")
|
||||
}
|
||||
|
||||
result := reMigration.FindStringSubmatch(m)
|
||||
|
||||
migration = model.SchemaMigration{
|
||||
Name: strings.ReplaceAll(result[2], "_", " "),
|
||||
Provider: result[3],
|
||||
Name: strings.ReplaceAll(result[reMigration.SubexpIndex("Name")], "_", " "),
|
||||
Provider: result[reMigration.SubexpIndex("Provider")],
|
||||
}
|
||||
|
||||
data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
|
||||
|
@ -148,22 +148,22 @@ func scanMigration(m string) (migration model.SchemaMigration, err error) {
|
|||
|
||||
migration.Query = string(data)
|
||||
|
||||
switch result[4] {
|
||||
switch direction := result[reMigration.SubexpIndex("Direction")]; direction {
|
||||
case "up":
|
||||
migration.Up = true
|
||||
case "down":
|
||||
migration.Up = false
|
||||
default:
|
||||
return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4])
|
||||
return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in Direction group '%s' must be up or down", direction)
|
||||
}
|
||||
|
||||
migration.Version, _ = strconv.Atoi(result[1])
|
||||
migration.Version, _ = strconv.Atoi(result[reMigration.SubexpIndex("Version")])
|
||||
|
||||
switch migration.Provider {
|
||||
case providerAll, providerSQLite, providerMySQL, providerPostgres:
|
||||
break
|
||||
default:
|
||||
return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3])
|
||||
return model.SchemaMigration{}, fmt.Errorf("invalid migration: value in Provider group '%s' must be all, sqlite, postgres, or mysql", migration.Provider)
|
||||
}
|
||||
|
||||
return migration, nil
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
DROP TABLE IF EXISTS oauth2_par_context;
|
|
@ -0,0 +1,17 @@
|
|||
CREATE TABLE IF NOT EXISTS oauth2_par_context (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
||||
request_id VARCHAR(40) NOT NULL,
|
||||
client_id VARCHAR(255) NOT NULL,
|
||||
signature VARCHAR(255) NOT NULL,
|
||||
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
scopes TEXT NOT NULL,
|
||||
audience TEXT NOT NULL,
|
||||
handled_response_types TEXT NOT NULL,
|
||||
response_mode TEXT NOT NULL,
|
||||
response_mode_default TEXT NOT NULL,
|
||||
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
form_data TEXT NOT NULL,
|
||||
session_data BLOB NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
|
||||
|
||||
CREATE UNIQUE INDEX oauth2_par_context_signature_key ON oauth2_par_context (signature);
|
|
@ -0,0 +1,17 @@
|
|||
CREATE TABLE IF NOT EXISTS oauth2_par_context (
|
||||
id SERIAL CONSTRAINT oauth2_par_context_pkey PRIMARY KEY,
|
||||
request_id VARCHAR(40) NOT NULL,
|
||||
client_id VARCHAR(255) NOT NULL,
|
||||
signature VARCHAR(255) NOT NULL,
|
||||
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
scopes TEXT NOT NULL,
|
||||
audience TEXT NULL DEFAULT '',
|
||||
handled_response_types TEXT NOT NULL DEFAULT '',
|
||||
response_mode TEXT NOT NULL DEFAULT '',
|
||||
response_mode_default TEXT NOT NULL DEFAULT '',
|
||||
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
form_data TEXT NOT NULL,
|
||||
session_data BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX oauth2_par_context_signature_key ON oauth2_par_context (signature);
|
|
@ -0,0 +1,17 @@
|
|||
CREATE TABLE IF NOT EXISTS oauth2_par_context (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
signature VARCHAR(255) NOT NULL,
|
||||
request_id VARCHAR(40) NOT NULL,
|
||||
client_id VARCHAR(255) NOT NULL,
|
||||
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
scopes TEXT NOT NULL,
|
||||
audience TEXT NOT NULL,
|
||||
handled_response_types TEXT NOT NULL,
|
||||
response_mode TEXT NOT NULL,
|
||||
response_mode_default TEXT NOT NULL,
|
||||
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
form_data TEXT NOT NULL,
|
||||
session_data BLOB NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX oauth2_par_context_signature_key ON oauth2_par_context (signature);
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
const (
|
||||
// This is the latest schema version for the purpose of tests.
|
||||
LatestVersion = 7
|
||||
LatestVersion = 8
|
||||
)
|
||||
|
||||
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
||||
|
|
|
@ -24,8 +24,8 @@ type Provider interface {
|
|||
LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error)
|
||||
|
||||
SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error)
|
||||
LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (subject *model.UserOpaqueIdentifier, err error)
|
||||
LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error)
|
||||
LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error)
|
||||
LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error)
|
||||
LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error)
|
||||
|
||||
SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error)
|
||||
|
@ -65,6 +65,10 @@ type Provider interface {
|
|||
DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error)
|
||||
LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error)
|
||||
|
||||
SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error)
|
||||
LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error)
|
||||
RevokeOAuth2PARContext(ctx context.Context, signature string) (err error)
|
||||
|
||||
SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error)
|
||||
LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error)
|
||||
|
||||
|
|
|
@ -70,6 +70,13 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier),
|
||||
sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier),
|
||||
|
||||
sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
|
||||
sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
|
||||
|
||||
sqlInsertOAuth2PARContext: fmt.Sprintf(queryFmtInsertOAuth2PARContext, tableOAuth2PARContext),
|
||||
sqlSelectOAuth2PARContext: fmt.Sprintf(queryFmtSelectOAuth2PARContext, tableOAuth2PARContext),
|
||||
sqlRevokeOAuth2PARContext: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PARContext),
|
||||
|
||||
sqlInsertOAuth2ConsentPreConfiguration: fmt.Sprintf(queryFmtInsertOAuth2ConsentPreConfiguration, tableOAuth2ConsentPreConfiguration),
|
||||
sqlSelectOAuth2ConsentPreConfigurations: fmt.Sprintf(queryFmtSelectOAuth2ConsentPreConfigurations, tableOAuth2ConsentPreConfiguration),
|
||||
|
||||
|
@ -79,13 +86,6 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession),
|
||||
sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession),
|
||||
|
||||
sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
|
||||
sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
|
||||
|
||||
sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession),
|
||||
sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession),
|
||||
sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession),
|
||||
|
@ -93,19 +93,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession),
|
||||
sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
|
||||
|
||||
sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
|
||||
sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
|
||||
|
||||
sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
|
||||
sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
|
||||
sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
|
||||
sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
||||
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
|
||||
|
||||
sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession),
|
||||
sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession),
|
||||
|
@ -114,8 +107,19 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession),
|
||||
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
|
||||
|
||||
sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
|
||||
sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
|
||||
sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
|
||||
sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
|
||||
sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
|
||||
|
||||
sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
|
||||
sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
|
||||
sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
|
||||
|
||||
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
|
||||
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
|
||||
|
@ -224,13 +228,18 @@ type SQLProvider struct {
|
|||
sqlDeactivateOAuth2AccessTokenSession string
|
||||
sqlDeactivateOAuth2AccessTokenSessionByRequestID string
|
||||
|
||||
// Table: oauth2_refresh_token_session.
|
||||
sqlInsertOAuth2RefreshTokenSession string
|
||||
sqlSelectOAuth2RefreshTokenSession string
|
||||
sqlRevokeOAuth2RefreshTokenSession string
|
||||
sqlRevokeOAuth2RefreshTokenSessionByRequestID string
|
||||
sqlDeactivateOAuth2RefreshTokenSession string
|
||||
sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
|
||||
// Table: oauth2_openid_connect_session.
|
||||
sqlInsertOAuth2OpenIDConnectSession string
|
||||
sqlSelectOAuth2OpenIDConnectSession string
|
||||
sqlRevokeOAuth2OpenIDConnectSession string
|
||||
sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
|
||||
sqlDeactivateOAuth2OpenIDConnectSession string
|
||||
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
|
||||
|
||||
// Table: oauth2_par_context.
|
||||
sqlInsertOAuth2PARContext string
|
||||
sqlSelectOAuth2PARContext string
|
||||
sqlRevokeOAuth2PARContext string
|
||||
|
||||
// Table: oauth2_pkce_request_session.
|
||||
sqlInsertOAuth2PKCERequestSession string
|
||||
|
@ -240,13 +249,13 @@ type SQLProvider struct {
|
|||
sqlDeactivateOAuth2PKCERequestSession string
|
||||
sqlDeactivateOAuth2PKCERequestSessionByRequestID string
|
||||
|
||||
// Table: oauth2_openid_connect_session.
|
||||
sqlInsertOAuth2OpenIDConnectSession string
|
||||
sqlSelectOAuth2OpenIDConnectSession string
|
||||
sqlRevokeOAuth2OpenIDConnectSession string
|
||||
sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
|
||||
sqlDeactivateOAuth2OpenIDConnectSession string
|
||||
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
|
||||
// Table: oauth2_refresh_token_session.
|
||||
sqlInsertOAuth2RefreshTokenSession string
|
||||
sqlSelectOAuth2RefreshTokenSession string
|
||||
sqlRevokeOAuth2RefreshTokenSession string
|
||||
sqlRevokeOAuth2RefreshTokenSessionByRequestID string
|
||||
sqlDeactivateOAuth2RefreshTokenSession string
|
||||
sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
|
||||
|
||||
sqlUpsertOAuth2BlacklistedJTI string
|
||||
sqlSelectOAuth2BlacklistedJTI string
|
||||
|
@ -339,19 +348,19 @@ func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
|
||||
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil {
|
||||
return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err)
|
||||
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, subject.Service, subject.SectorID, subject.Username, subject.Identifier); err != nil {
|
||||
return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", subject.Username, subject.Identifier.String(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
|
||||
func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
||||
opaqueID = &model.UserOpaqueIdentifier{}
|
||||
func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, identifier uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) {
|
||||
subject = &model.UserOpaqueIdentifier{}
|
||||
|
||||
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil {
|
||||
if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifier, identifier); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return nil, nil
|
||||
|
@ -360,11 +369,11 @@ func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID u
|
|||
}
|
||||
}
|
||||
|
||||
return opaqueID, nil
|
||||
return subject, nil
|
||||
}
|
||||
|
||||
// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database.
|
||||
func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) {
|
||||
func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (identifiers []model.UserOpaqueIdentifier, err error) {
|
||||
var rows *sqlx.Rows
|
||||
|
||||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil {
|
||||
|
@ -380,17 +389,17 @@ func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs
|
|||
return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err)
|
||||
}
|
||||
|
||||
opaqueIDs = append(opaqueIDs, *opaqueID)
|
||||
identifiers = append(identifiers, *opaqueID)
|
||||
}
|
||||
|
||||
return opaqueIDs, nil
|
||||
return identifiers, nil
|
||||
}
|
||||
|
||||
// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
|
||||
func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
||||
opaqueID = &model.UserOpaqueIdentifier{}
|
||||
func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) {
|
||||
subject = &model.UserOpaqueIdentifier{}
|
||||
|
||||
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
|
||||
if err = p.db.GetContext(ctx, subject, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return nil, nil
|
||||
|
@ -399,7 +408,7 @@ func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, s
|
|||
}
|
||||
}
|
||||
|
||||
return opaqueID, nil
|
||||
return subject, nil
|
||||
}
|
||||
|
||||
// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session.
|
||||
|
@ -496,22 +505,22 @@ func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2S
|
|||
var query string
|
||||
|
||||
switch sessionType {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlInsertOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
query = p.sqlInsertOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlInsertOAuth2RefreshTokenSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlInsertOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlInsertOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
query = p.sqlInsertOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlInsertOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlInsertOAuth2RefreshTokenSession
|
||||
default:
|
||||
return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType)
|
||||
}
|
||||
|
||||
if session.Session, err = p.encrypt(session.Session); err != nil {
|
||||
return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
|
||||
return fmt.Errorf("error encrypting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
|
||||
}
|
||||
|
||||
_, err = p.db.ExecContext(ctx, query,
|
||||
|
@ -532,16 +541,16 @@ func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth
|
|||
var query string
|
||||
|
||||
switch sessionType {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlRevokeOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
query = p.sqlRevokeOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlRevokeOAuth2RefreshTokenSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlRevokeOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlRevokeOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
query = p.sqlRevokeOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlRevokeOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlRevokeOAuth2RefreshTokenSession
|
||||
default:
|
||||
return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
|
||||
}
|
||||
|
@ -558,16 +567,16 @@ func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessio
|
|||
var query string
|
||||
|
||||
switch sessionType {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
|
||||
default:
|
||||
return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
|
||||
}
|
||||
|
@ -584,16 +593,16 @@ func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType O
|
|||
var query string
|
||||
|
||||
switch sessionType {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
query = p.sqlDeactivateOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlDeactivateOAuth2RefreshTokenSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlDeactivateOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
query = p.sqlDeactivateOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlDeactivateOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlDeactivateOAuth2RefreshTokenSession
|
||||
default:
|
||||
return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType.String())
|
||||
}
|
||||
|
@ -610,16 +619,16 @@ func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, se
|
|||
var query string
|
||||
|
||||
switch sessionType {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
|
||||
default:
|
||||
return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType.String())
|
||||
}
|
||||
|
@ -636,16 +645,16 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
|
|||
var query string
|
||||
|
||||
switch sessionType {
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlSelectOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
query = p.sqlSelectOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlSelectOAuth2RefreshTokenSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlSelectOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
query = p.sqlSelectOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
query = p.sqlSelectOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
query = p.sqlSelectOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
query = p.sqlSelectOAuth2RefreshTokenSession
|
||||
default:
|
||||
return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType.String())
|
||||
}
|
||||
|
@ -663,6 +672,45 @@ func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2S
|
|||
return session, nil
|
||||
}
|
||||
|
||||
// SaveOAuth2PARContext save a OAuth2PARContext to the database.
|
||||
func (p *SQLProvider) SaveOAuth2PARContext(ctx context.Context, par model.OAuth2PARContext) (err error) {
|
||||
if par.Session, err = p.encrypt(par.Session); err != nil {
|
||||
return fmt.Errorf("error encrypting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2PARContext,
|
||||
par.Signature, par.RequestID, par.ClientID, par.RequestedAt, par.Scopes, par.Audience, par.HandledResponseTypes,
|
||||
par.ResponseMode, par.DefaultResponseMode, par.Revoked, par.Form, par.Session); err != nil {
|
||||
return fmt.Errorf("error inserting oauth2 pushed authorization request context data for with signature '%s' and request id '%s': %w", par.Signature, par.RequestID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadOAuth2PARContext loads a OAuth2PARContext from the database.
|
||||
func (p *SQLProvider) LoadOAuth2PARContext(ctx context.Context, signature string) (par *model.OAuth2PARContext, err error) {
|
||||
par = &model.OAuth2PARContext{}
|
||||
|
||||
if err = p.db.GetContext(ctx, par, p.sqlSelectOAuth2PARContext, signature); err != nil {
|
||||
return nil, fmt.Errorf("error selecting oauth2 pushed authorization request context with signature '%s': %w", signature, err)
|
||||
}
|
||||
|
||||
if par.Session, err = p.decrypt(par.Session); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting oauth2 oauth2 pushed authorization request context data with signature '%s' and request id '%s': %w", signature, par.RequestID, err)
|
||||
}
|
||||
|
||||
return par, nil
|
||||
}
|
||||
|
||||
// RevokeOAuth2PARContext marks a OAuth2PARContext as revoked in the database.
|
||||
func (p *SQLProvider) RevokeOAuth2PARContext(ctx context.Context, signature string) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOAuth2PARContext, signature); err != nil {
|
||||
return fmt.Errorf("error revoking oauth2 pushed authorization request context with signature '%s': %w", signature, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
|
||||
func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
|
||||
|
@ -762,7 +810,7 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
|
|||
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
|
||||
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
|
||||
if config.Secret, err = p.encrypt(config.Secret); err != nil {
|
||||
return fmt.Errorf("error encrypting the TOTP configuration secret for user '%s': %w", config.Username, err)
|
||||
return fmt.Errorf("error encrypting TOTP configuration secret for user '%s': %w", config.Username, err)
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
|
||||
|
@ -806,7 +854,7 @@ func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string
|
|||
}
|
||||
|
||||
if config.Secret, err = p.decrypt(config.Secret); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting the TOTP secret for user '%s': %w", username, err)
|
||||
return nil, fmt.Errorf("error decrypting TOTP secret for user '%s': %w", username, err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
|
@ -836,7 +884,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
|
|||
// SaveWebauthnDevice saves a registered Webauthn device.
|
||||
func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) {
|
||||
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
|
||||
return fmt.Errorf("error encrypting the Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
|
||||
return fmt.Errorf("error encrypting Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice,
|
||||
|
|
|
@ -87,13 +87,6 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
|
|||
provider.sqlUpdateOAuth2ConsentSessionGranted = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionGranted)
|
||||
provider.sqlSelectOAuth2ConsentSessionByChallengeID = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionByChallengeID)
|
||||
|
||||
provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession)
|
||||
provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession)
|
||||
provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID)
|
||||
provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession)
|
||||
provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID)
|
||||
provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession)
|
||||
|
||||
provider.sqlInsertOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2AccessTokenSession)
|
||||
provider.sqlRevokeOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSession)
|
||||
provider.sqlRevokeOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSessionByRequestID)
|
||||
|
@ -101,12 +94,23 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
|
|||
provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID)
|
||||
provider.sqlSelectOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2AccessTokenSession)
|
||||
|
||||
provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession)
|
||||
provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession)
|
||||
provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID)
|
||||
provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession)
|
||||
provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID)
|
||||
provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession)
|
||||
provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession)
|
||||
provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession)
|
||||
provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID)
|
||||
provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession)
|
||||
provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID)
|
||||
provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession)
|
||||
|
||||
provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession)
|
||||
provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession)
|
||||
provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID)
|
||||
provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession)
|
||||
provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID)
|
||||
provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession)
|
||||
|
||||
provider.sqlInsertOAuth2PARContext = provider.db.Rebind(provider.sqlInsertOAuth2PARContext)
|
||||
provider.sqlRevokeOAuth2PARContext = provider.db.Rebind(provider.sqlRevokeOAuth2PARContext)
|
||||
provider.sqlSelectOAuth2PARContext = provider.db.Rebind(provider.sqlSelectOAuth2PARContext)
|
||||
|
||||
provider.sqlInsertOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlInsertOAuth2PKCERequestSession)
|
||||
provider.sqlRevokeOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSession)
|
||||
|
@ -115,12 +119,12 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
|
|||
provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID)
|
||||
provider.sqlSelectOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlSelectOAuth2PKCERequestSession)
|
||||
|
||||
provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession)
|
||||
provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession)
|
||||
provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID)
|
||||
provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession)
|
||||
provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID)
|
||||
provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession)
|
||||
provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession)
|
||||
provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession)
|
||||
provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID)
|
||||
provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession)
|
||||
provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID)
|
||||
provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession)
|
||||
|
||||
provider.sqlSelectOAuth2BlacklistedJTI = provider.db.Rebind(provider.sqlSelectOAuth2BlacklistedJTI)
|
||||
|
||||
|
|
|
@ -314,6 +314,19 @@ const (
|
|||
SET active = FALSE
|
||||
WHERE request_id = ?;`
|
||||
|
||||
queryFmtSelectOAuth2PARContext = `
|
||||
SELECT id, signature, request_id, client_id, requested_at, scopes, audience,
|
||||
handled_response_types, response_mode, response_mode_default, revoked,
|
||||
form_data, session_data
|
||||
FROM %s
|
||||
WHERE signature = ? AND revoked = FALSE;`
|
||||
|
||||
queryFmtInsertOAuth2PARContext = `
|
||||
INSERT INTO %s (signature, request_id, client_id, requested_at, scopes, audience,
|
||||
handled_response_types, response_mode, response_mode_default, revoked,
|
||||
form_data, session_data)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
|
||||
|
||||
queryFmtSelectOAuth2BlacklistedJTI = `
|
||||
SELECT id, signature, expires_at
|
||||
FROM %s
|
||||
|
|
|
@ -1132,6 +1132,7 @@ func (s *CLISuite) TestStorage05ShouldChangeEncryptionKey() {
|
|||
s.Assert().Contains(output, "\n\n\tTable (oauth2_openid_connect_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (oauth2_pkce_request_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (oauth2_refresh_token_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (oauth2_par_context): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (totp_configurations): FAILURE\n\t\tInvalid Rows: 4\n\t\tTotal Rows: 4\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (webauthn_devices): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
|
||||
|
@ -1149,6 +1150,7 @@ func (s *CLISuite) TestStorage05ShouldChangeEncryptionKey() {
|
|||
s.Assert().Contains(output, "\n\n\tTable (oauth2_openid_connect_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (oauth2_pkce_request_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (oauth2_refresh_token_session): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (oauth2_par_context): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (totp_configurations): SUCCESS\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 4\n")
|
||||
s.Assert().Contains(output, "\n\n\tTable (webauthn_devices): N/A\n\t\tInvalid Rows: 0\n\t\tTotal Rows: 0\n")
|
||||
|
||||
|
|
|
@ -104,8 +104,13 @@ func IsStringSliceContainsAll(needles []string, haystack []string) (inSlice bool
|
|||
|
||||
// IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles.
|
||||
func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) {
|
||||
return IsStringSliceContainsAnyF(needles, haystack, IsStringInSlice)
|
||||
}
|
||||
|
||||
// IsStringSliceContainsAnyF checks if the haystack contains any of the strings in the needles using the isInSlice func.
|
||||
func IsStringSliceContainsAnyF(needles []string, haystack []string, isInSlice func(needle string, haystack []string) bool) (inSlice bool) {
|
||||
for _, n := range needles {
|
||||
if IsStringInSlice(n, haystack) {
|
||||
if isInSlice(n, haystack) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue