feat(storage): encrypt u2f key (#2664)
Adds encryption to the U2F public keys. While the public keys cannot be used to authenticate, only to validate someone is authenticated, if a rogue operator changed these in the database they may be able to bypass 2FA. This prevents that.pull/2666/head
parent
104a61ecd6
commit
255aaeb2ad
|
@ -31,12 +31,12 @@ required: yes
|
||||||
{: .label .label-config .label-red }
|
{: .label .label-config .label-red }
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
The encryption key used to encrypt data in the database. It has a minimum length of 20 and must be provided. We encrypt
|
The encryption key used to encrypt data in the database. We encrypt data by creating a sha256 checksum of the provided
|
||||||
data by creating a sha256 checksum of the provided value, and use that to encrypt the data with the AES-GCM 256bit
|
value, and use that to encrypt the data with the AES-GCM 256bit algorithm.
|
||||||
algorithm.
|
|
||||||
|
|
||||||
The encrypted data in the database is as follows:
|
The minimum length of this key is 20 characters, however we generally recommend above 64 characters.
|
||||||
- TOTP Secret
|
|
||||||
|
See [securty measures](../../security/measures.md#storage-security-measures) for more information.
|
||||||
|
|
||||||
### local
|
### local
|
||||||
See [SQLite](./sqlite.md).
|
See [SQLite](./sqlite.md).
|
||||||
|
|
|
@ -81,6 +81,54 @@ LDAP implementations vary, so please ask if you need some assistance in configur
|
||||||
These protections can be [tuned](../configuration/authentication/ldap.md#refresh-interval) according to your security
|
These protections can be [tuned](../configuration/authentication/ldap.md#refresh-interval) according to your security
|
||||||
policy by changing refresh_interval, however we believe that 5 minutes is a fairly safe interval.
|
policy by changing refresh_interval, however we believe that 5 minutes is a fairly safe interval.
|
||||||
|
|
||||||
|
## Storage security measures
|
||||||
|
|
||||||
|
We force users to encrypt vulnerable data stored in the database. It is strongly advised you do not give this encryption
|
||||||
|
key to anyone. In the instance of a database installation that multiple users have access to, you should aim to ensure
|
||||||
|
that users who have access to the database do not also have access to this key.
|
||||||
|
|
||||||
|
The encrypted data in the database is as follows:
|
||||||
|
|
||||||
|
|Table |Column |Rational |
|
||||||
|
|:-----------------:|:--------:|:----------------------------------------------------------------------------------------------------:|
|
||||||
|
|totp_configurations|secret |Prevents a [Leaked Database](#leaked-database) or [Bad Actors](#bad-actors) from compromising security|
|
||||||
|
|u2f_devices |public_key|Prevents [Bad Actors](#bad-actors) from compromising security |
|
||||||
|
|
||||||
|
### Leaked Database
|
||||||
|
|
||||||
|
A leaked database can reasonably compromise security if there are credentials that are not encrypted. Columns encrypted
|
||||||
|
for this purpose prevent this attack vector.
|
||||||
|
|
||||||
|
### Bad Actors
|
||||||
|
|
||||||
|
A bad actor who has the SQL password and access to the database can theoretically change another users credential, this
|
||||||
|
theoretically bypasses authentication. Columns encrypted for this purpose prevent this attack vector.
|
||||||
|
|
||||||
|
A bad actor may also be able to use data in the database to bypass 2FA silently depending on the credentials. In the
|
||||||
|
instance of the U2F public key this is not possible, they can only change it which would eventually alert the user in
|
||||||
|
question. But in the case of TOTP they can use the secret to authenticate without knowledge of the user in question.
|
||||||
|
|
||||||
|
### Encryption key management
|
||||||
|
|
||||||
|
You must supply the encryption key in the recommended method of a [secret](../configuration/secrets.md) or in one of
|
||||||
|
the other methods available for [configuration](../configuration/index.md#configuration).
|
||||||
|
|
||||||
|
If you wish to change your encryption key for any reason you can do so using the following steps:
|
||||||
|
|
||||||
|
1. Run the `authelia --version` command to determine the version of Authelia you're running and either download that
|
||||||
|
version or run another container of that version interactively. All the subsequent commands assume you're running
|
||||||
|
the `authelia` binary in the current working directory. You will have to adjust this according to how you're running
|
||||||
|
it.
|
||||||
|
2. Run the `./authelia storage encryption change-key --help` command.
|
||||||
|
3. Stop Authelia.
|
||||||
|
- You can skip this step, however note that any data changed between the time you make the change and the time when
|
||||||
|
you stop Authelia i.e. via user registering a device; will be encrypted with the incorrect key.
|
||||||
|
4. Run the `./authelia storage encryption change-key` command with the appropriate parameters.
|
||||||
|
- The help from step 1 will be useful here. The easiest method to accomplish this is with the `--config`,
|
||||||
|
`--encryption-key`, and `--new-encryption-key` parameters.
|
||||||
|
5. Update the encryption key Authelia uses on startup.
|
||||||
|
6. Start Authelia.
|
||||||
|
|
||||||
## Notifier security measures (SMTP)
|
## Notifier security measures (SMTP)
|
||||||
|
|
||||||
The SMTP Notifier implementation does not allow connections that are not secure without changing default configuration
|
The SMTP Notifier implementation does not allow connections that are not secure without changing default configuration
|
||||||
|
|
|
@ -52,7 +52,6 @@ If properly configured, Authelia guarantees the following for security of your u
|
||||||
* Binding session cookies to single IP addresses.
|
* Binding session cookies to single IP addresses.
|
||||||
* Authenticate communication between Authelia and reverse proxy.
|
* Authenticate communication between Authelia and reverse proxy.
|
||||||
* Securely transmit authentication data to backends (OAuth2 with bearer tokens).
|
* Securely transmit authentication data to backends (OAuth2 with bearer tokens).
|
||||||
* Protect secrets stored in the database with encryption to prevent secrets leak by database exfiltration.
|
|
||||||
* Least privilege on LDAP binding operations (currently administrative user is used to bind while it could be anonymous
|
* Least privilege on LDAP binding operations (currently administrative user is used to bind while it could be anonymous
|
||||||
for most operations).
|
for most operations).
|
||||||
* Extend the check of user group memberships to authentication backends other than LDAP (File currently).
|
* Extend the check of user group memberships to authentication backends other than LDAP (File currently).
|
||||||
|
|
|
@ -65,7 +65,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthType1FA,
|
Type: regulation.AuthType1FA,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.Ctx.Request.SetBodyString(`{
|
s.mock.Ctx.Request.SetBodyString(`{
|
||||||
|
@ -93,7 +93,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthType1FA,
|
Type: regulation.AuthType1FA,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.Ctx.Request.SetBodyString(`{
|
s.mock.Ctx.Request.SetBodyString(`{
|
||||||
|
@ -119,7 +119,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthType1FA,
|
Type: regulation.AuthType1FA,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.Ctx.Request.SetBodyString(`{
|
s.mock.Ctx.Request.SetBodyString(`{
|
||||||
|
|
|
@ -34,21 +34,21 @@ func (s *HandlerRegisterU2FStep1Suite) TearDownTest() {
|
||||||
s.mock.Close()
|
s.mock.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
||||||
verification = models.NewIdentityVerification(username, action)
|
verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP())
|
||||||
|
|
||||||
verification.ExpiresAt = expiresAt
|
verification.ExpiresAt = expiresAt
|
||||||
|
|
||||||
claims := verification.ToIdentityVerificationClaim()
|
claims := verification.ToIdentityVerificationClaim()
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
ss, _ := token.SignedString([]byte(secret))
|
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
|
||||||
|
|
||||||
return ss, verification
|
return ss, verification
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() {
|
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() {
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration,
|
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
|
||||||
Return(true, nil)
|
Return(true, nil)
|
||||||
|
|
||||||
s.mock.StorageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
||||||
|
@ -68,7 +68,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
|
||||||
|
|
||||||
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
||||||
s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration,
|
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin
|
||||||
Return(true, nil)
|
Return(true, nil)
|
||||||
|
|
||||||
s.mock.StorageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
||||||
|
|
|
@ -97,7 +97,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -286,7 +286,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -414,7 +414,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndDenyAccess() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -497,7 +497,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToDefaultURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -546,7 +546,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotReturnRedirectURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -591,7 +591,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToSafeTargetURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -640,7 +640,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotRedirectToUnsafeURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
@ -687,7 +687,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRegenerateSessionForPreventingSessi
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeDuo,
|
Type: regulation.AuthTypeDuo,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
})).
|
})).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeTOTP,
|
Type: regulation.AuthTypeTOTP,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
||||||
|
@ -85,7 +85,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeTOTP,
|
Type: regulation.AuthTypeTOTP,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
||||||
|
@ -115,7 +115,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeTOTP,
|
Type: regulation.AuthTypeTOTP,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
||||||
|
@ -146,7 +146,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeTOTP,
|
Type: regulation.AuthTypeTOTP,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.TOTPMock.EXPECT().
|
s.mock.TOTPMock.EXPECT().
|
||||||
|
@ -180,7 +180,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFi
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeTOTP,
|
Type: regulation.AuthTypeTOTP,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.TOTPMock.EXPECT().
|
s.mock.TOTPMock.EXPECT().
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToDefaultURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeU2F,
|
Type: regulation.AuthTypeU2F,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL
|
s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL
|
||||||
|
@ -83,7 +83,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotReturnRedirectURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeU2F,
|
Type: regulation.AuthTypeU2F,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||||
|
@ -111,7 +111,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToSafeTargetURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeU2F,
|
Type: regulation.AuthTypeU2F,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||||
|
@ -142,7 +142,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotRedirectToUnsafeURL() {
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeU2F,
|
Type: regulation.AuthTypeU2F,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||||
|
@ -171,7 +171,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRegenerateSessionForPreventingSessi
|
||||||
Banned: false,
|
Banned: false,
|
||||||
Time: s.mock.Clock.Now(),
|
Time: s.mock.Clock.Now(),
|
||||||
Type: regulation.AuthTypeU2F,
|
Type: regulation.AuthTypeU2F,
|
||||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||||
|
|
|
@ -27,7 +27,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
verification := models.NewIdentityVerification(identity.Username, args.ActionClaim)
|
verification := models.NewIdentityVerification(identity.Username, args.ActionClaim, ctx.RemoteIP())
|
||||||
|
|
||||||
// Create the claim with the action to sign it.
|
// Create the claim with the action to sign it.
|
||||||
claims := verification.ToIdentityVerificationClaim()
|
claims := verification.ToIdentityVerificationClaim()
|
||||||
|
@ -183,7 +183,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, claims.ID)
|
err = ctx.Providers.StorageProvider.ConsumeIdentityVerification(ctx, claims.ID, models.NewNullIP(ctx.RemoteIP()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(err, messageOperationFailed)
|
ctx.Error(err, messageOperationFailed)
|
||||||
return
|
return
|
||||||
|
|
|
@ -165,15 +165,15 @@ func (s *IdentityVerificationFinishProcess) TearDownTest() {
|
||||||
s.mock.Close()
|
s.mock.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
||||||
verification = models.NewIdentityVerification(username, action)
|
verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP())
|
||||||
|
|
||||||
verification.ExpiresAt = expiresAt
|
verification.ExpiresAt = expiresAt
|
||||||
|
|
||||||
claims := verification.ToIdentityVerificationClaim()
|
claims := verification.ToIdentityVerificationClaim()
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
ss, _ := token.SignedString([]byte(secret))
|
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
|
||||||
|
|
||||||
return ss, verification
|
return ss, verification
|
||||||
}
|
}
|
||||||
|
@ -203,7 +203,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotProvided()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() {
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() {
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "Login",
|
token, verification := createToken(s.mock, "john", "Login",
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
|
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
@ -229,7 +229,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() {
|
||||||
|
|
||||||
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
||||||
args := newArgs(defaultRetriever)
|
args := newArgs(defaultRetriever)
|
||||||
token, _ := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", args.ActionClaim,
|
token, _ := createToken(s.mock, "john", args.ActionClaim,
|
||||||
time.Now().Add(-1*time.Minute))
|
time.Now().Add(-1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -240,7 +240,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "", "",
|
token, verification := createToken(s.mock, "", "",
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -255,7 +255,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "harry", "EXP_ACTION",
|
token, verification := createToken(s.mock, "harry", "EXP_ACTION",
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -272,7 +272,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() {
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() {
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION",
|
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -281,7 +281,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
|
||||||
Return(true, nil)
|
Return(true, nil)
|
||||||
|
|
||||||
s.mock.StorageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||||
Return(fmt.Errorf("cannot remove"))
|
Return(fmt.Errorf("cannot remove"))
|
||||||
|
|
||||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||||
|
@ -291,7 +291,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() {
|
func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() {
|
||||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION",
|
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
||||||
time.Now().Add(1*time.Minute))
|
time.Now().Add(1*time.Minute))
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||||
|
|
||||||
|
@ -300,7 +300,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete(
|
||||||
Return(true, nil)
|
Return(true, nil)
|
||||||
|
|
||||||
s.mock.StorageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
|
|
||||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||||
|
|
|
@ -65,6 +65,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConsumeIdentityVerification mocks base method.
|
||||||
|
func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 models.NullIP) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ConsumeIdentityVerification", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsumeIdentityVerification indicates an expected call of ConsumeIdentityVerification.
|
||||||
|
func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
// DeletePreferredDuoDevice mocks base method.
|
// DeletePreferredDuoDevice mocks base method.
|
||||||
func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error {
|
func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -198,6 +212,21 @@ func (mr *MockStorageMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevice), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevice), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadU2FDevices mocks base method.
|
||||||
|
func (m *MockStorage) LoadU2FDevices(arg0 context.Context, arg1, arg2 int) ([]models.U2FDevice, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LoadU2FDevices", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].([]models.U2FDevice)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadU2FDevices indicates an expected call of LoadU2FDevices.
|
||||||
|
func (mr *MockStorageMockRecorder) LoadU2FDevices(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevices", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevices), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
// LoadUserInfo mocks base method.
|
// LoadUserInfo mocks base method.
|
||||||
func (m *MockStorage) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) {
|
func (m *MockStorage) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -213,20 +242,6 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveIdentityVerification mocks base method.
|
|
||||||
func (m *MockStorage) RemoveIdentityVerification(arg0 context.Context, arg1 string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification.
|
|
||||||
func (mr *MockStorageMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).RemoveIdentityVerification), arg0, arg1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveIdentityVerification mocks base method.
|
// SaveIdentityVerification mocks base method.
|
||||||
func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error {
|
func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -442,17 +457,3 @@ func (mr *MockStorageMockRecorder) StartupCheck() *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockStorage)(nil).StartupCheck))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockStorage)(nil).StartupCheck))
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTOTPConfigurationSecret mocks base method.
|
|
||||||
func (m *MockStorage) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 models.TOTPConfiguration) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "UpdateTOTPConfigurationSecret", arg0, arg1)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateTOTPConfigurationSecret indicates an expected call of UpdateTOTPConfigurationSecret.
|
|
||||||
func (mr *MockStorageMockRecorder) UpdateTOTPConfigurationSecret(arg0, arg1 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTOTPConfigurationSecret", reflect.TypeOf((*MockStorage)(nil).UpdateTOTPConfigurationSecret), arg0, arg1)
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ type AuthenticationAttempt struct {
|
||||||
Banned bool `db:"banned"`
|
Banned bool `db:"banned"`
|
||||||
Username string `db:"username"`
|
Username string `db:"username"`
|
||||||
Type string `db:"auth_type"`
|
Type string `db:"auth_type"`
|
||||||
RemoteIP IPAddress `db:"remote_ip"`
|
RemoteIP NullIP `db:"remote_ip"`
|
||||||
RequestURI string `db:"request_uri"`
|
RequestURI string `db:"request_uri"`
|
||||||
RequestMethod string `db:"request_method"`
|
RequestMethod string `db:"request_method"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
@ -8,25 +9,28 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewIdentityVerification creates a new IdentityVerification from a given username and action.
|
// NewIdentityVerification creates a new IdentityVerification from a given username and action.
|
||||||
func NewIdentityVerification(username, action string) (verification IdentityVerification) {
|
func NewIdentityVerification(username, action string, ip net.IP) (verification IdentityVerification) {
|
||||||
return IdentityVerification{
|
return IdentityVerification{
|
||||||
JTI: uuid.New(),
|
JTI: uuid.New(),
|
||||||
IssuedAt: time.Now(),
|
IssuedAt: time.Now(),
|
||||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||||
Action: action,
|
Action: action,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
IssuedIP: NewIP(ip),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IdentityVerification represents an identity verification row in the database.
|
// IdentityVerification represents an identity verification row in the database.
|
||||||
type IdentityVerification struct {
|
type IdentityVerification struct {
|
||||||
ID int `db:"id"`
|
ID int `db:"id"`
|
||||||
JTI uuid.UUID `db:"jti"`
|
JTI uuid.UUID `db:"jti"`
|
||||||
IssuedAt time.Time `db:"iat"`
|
IssuedAt time.Time `db:"iat"`
|
||||||
ExpiresAt time.Time `db:"exp"`
|
IssuedIP IP `db:"issued_ip"`
|
||||||
Used *time.Time `db:"used"`
|
ExpiresAt time.Time `db:"exp"`
|
||||||
Action string `db:"action"`
|
Action string `db:"action"`
|
||||||
Username string `db:"username"`
|
Username string `db:"username"`
|
||||||
|
Consumed *time.Time `db:"consumed"`
|
||||||
|
ConsumedIP NullIP `db:"consumed_ip"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToIdentityVerificationClaim converts the IdentityVerification into a IdentityVerificationClaim.
|
// ToIdentityVerificationClaim converts the IdentityVerification into a IdentityVerificationClaim.
|
||||||
|
|
|
@ -2,23 +2,71 @@ package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewIPAddressFromString converts a string into an IPAddress.
|
// NewIP easily constructs a new IP.
|
||||||
func NewIPAddressFromString(ip string) (ipAddress IPAddress) {
|
func NewIP(value net.IP) (ip IP) {
|
||||||
actualIP := net.ParseIP(ip)
|
return IP{IP: value}
|
||||||
return IPAddress{IP: &actualIP}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPAddress is a type specific for storage of a net.IP in the database.
|
// NewNullIP easily constructs a new NullIP.
|
||||||
type IPAddress struct {
|
func NewNullIP(value net.IP) (ip NullIP) {
|
||||||
*net.IP
|
return NullIP{IP: value}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value is the IPAddress implementation of the databases/sql driver.Valuer.
|
// NewNullIPFromString easily constructs a new NullIP from a string.
|
||||||
func (ip IPAddress) Value() (value driver.Value, err error) {
|
func NewNullIPFromString(value string) (ip NullIP) {
|
||||||
|
if value == "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
return NullIP{IP: net.ParseIP(value)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP is a type specific for storage of a net.IP in the database which can't be NULL.
|
||||||
|
type IP struct {
|
||||||
|
IP net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value is the IP implementation of the databases/sql driver.Valuer.
|
||||||
|
func (ip IP) Value() (value driver.Value, err error) {
|
||||||
|
if ip.IP == nil {
|
||||||
|
return nil, errors.New("cannot value nil IP to driver.Value")
|
||||||
|
}
|
||||||
|
|
||||||
|
return driver.Value(ip.IP.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan is the IP implementation of the sql.Scanner.
|
||||||
|
func (ip *IP) Scan(src interface{}) (err error) {
|
||||||
|
if src == nil {
|
||||||
|
return errors.New("cannot scan nil to type IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
var value string
|
||||||
|
|
||||||
|
switch v := src.(type) {
|
||||||
|
case string:
|
||||||
|
value = v
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid type %T for IP %v", src, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip.IP = net.ParseIP(value)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NullIP is a type specific for storage of a net.IP in the database which can also be NULL.
|
||||||
|
type NullIP struct {
|
||||||
|
IP net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value is the NullIP implementation of the databases/sql driver.Valuer.
|
||||||
|
func (ip NullIP) Value() (value driver.Value, err error) {
|
||||||
if ip.IP == nil {
|
if ip.IP == nil {
|
||||||
return driver.Value(nil), nil
|
return driver.Value(nil), nil
|
||||||
}
|
}
|
||||||
|
@ -26,8 +74,8 @@ func (ip IPAddress) Value() (value driver.Value, err error) {
|
||||||
return driver.Value(ip.IP.String()), nil
|
return driver.Value(ip.IP.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan is the IPAddress implementation of the sql.Scanner.
|
// Scan is the NullIP implementation of the sql.Scanner.
|
||||||
func (ip *IPAddress) Scan(src interface{}) (err error) {
|
func (ip *NullIP) Scan(src interface{}) (err error) {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
ip.IP = nil
|
ip.IP = nil
|
||||||
return nil
|
return nil
|
||||||
|
@ -39,10 +87,10 @@ func (ip *IPAddress) Scan(src interface{}) (err error) {
|
||||||
case string:
|
case string:
|
||||||
value = v
|
value = v
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid type %T for IPAddress %v", src, src)
|
return fmt.Errorf("invalid type %T for NullIP %v", src, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
*ip.IP = net.ParseIP(value)
|
ip.IP = net.ParseIP(value)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,7 @@ func (r *Regulator) Mark(ctx context.Context, successful, banned bool, username,
|
||||||
Banned: banned,
|
Banned: banned,
|
||||||
Username: username,
|
Username: username,
|
||||||
Type: authType,
|
Type: authType,
|
||||||
RemoteIP: models.IPAddress{IP: &remoteIP},
|
RemoteIP: models.NewNullIP(remoteIP),
|
||||||
RequestURI: requestURI,
|
RequestURI: requestURI,
|
||||||
RequestMethod: requestMethod,
|
RequestMethod: requestMethod,
|
||||||
})
|
})
|
||||||
|
|
|
@ -23,9 +23,11 @@ const (
|
||||||
|
|
||||||
// WARNING: Do not change/remove these consts. They are used for Pre1 migrations.
|
// WARNING: Do not change/remove these consts. They are used for Pre1 migrations.
|
||||||
const (
|
const (
|
||||||
tablePre1TOTPSecrets = "totp_secrets"
|
tablePre1TOTPSecrets = "totp_secrets"
|
||||||
tablePre1Config = "config"
|
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
|
||||||
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
|
|
||||||
|
tablePre1Config = "config"
|
||||||
|
|
||||||
tableAlphaAuthenticationLogs = "AuthenticationLogs"
|
tableAlphaAuthenticationLogs = "AuthenticationLogs"
|
||||||
tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
|
tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
|
||||||
tableAlphaPreferences = "Preferences"
|
tableAlphaPreferences = "Preferences"
|
||||||
|
@ -35,6 +37,15 @@ const (
|
||||||
tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
|
tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var tablesPre1 = []string{
|
||||||
|
tablePre1TOTPSecrets,
|
||||||
|
tablePre1IdentityVerificationTokens,
|
||||||
|
|
||||||
|
tableUserPreferences,
|
||||||
|
tableU2FDevices,
|
||||||
|
tableAuthenticationLogs,
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
providerAll = "all"
|
providerAll = "all"
|
||||||
providerMySQL = "mysql"
|
providerMySQL = "mysql"
|
||||||
|
|
|
@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
|
||||||
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
username VARCHAR(100) NOT NULL,
|
username VARCHAR(100) NOT NULL,
|
||||||
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
||||||
remote_ip VARCHAR(47) NULL DEFAULT NULL,
|
remote_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
request_uri TEXT NOT NULL,
|
request_uri TEXT NOT NULL,
|
||||||
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
||||||
PRIMARY KEY (id)
|
PRIMARY KEY (id)
|
||||||
|
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
|
||||||
id INTEGER AUTO_INCREMENT,
|
id INTEGER AUTO_INCREMENT,
|
||||||
jti CHAR(36),
|
jti CHAR(36),
|
||||||
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
issued_ip VARCHAR(39) NOT NULL,
|
||||||
exp TIMESTAMP NOT NULL,
|
exp TIMESTAMP NOT NULL,
|
||||||
used TIMESTAMP NULL DEFAULT NULL,
|
|
||||||
username VARCHAR(100) NOT NULL,
|
username VARCHAR(100) NOT NULL,
|
||||||
action VARCHAR(50) NOT NULL,
|
action VARCHAR(50) NOT NULL,
|
||||||
|
consumed TIMESTAMP NULL DEFAULT NULL,
|
||||||
|
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
PRIMARY KEY (id),
|
PRIMARY KEY (id),
|
||||||
UNIQUE KEY (jti)
|
UNIQUE KEY (jti)
|
||||||
);
|
);
|
||||||
|
|
|
@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
|
||||||
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
username VARCHAR(100) NOT NULL,
|
username VARCHAR(100) NOT NULL,
|
||||||
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
||||||
remote_ip VARCHAR(47) NULL DEFAULT NULL,
|
remote_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
request_uri TEXT,
|
request_uri TEXT,
|
||||||
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
||||||
PRIMARY KEY (id)
|
PRIMARY KEY (id)
|
||||||
|
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
|
||||||
id SERIAL,
|
id SERIAL,
|
||||||
jti CHAR(36),
|
jti CHAR(36),
|
||||||
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
issued_ip VARCHAR(39) NOT NULL,
|
||||||
exp TIMESTAMP WITH TIME ZONE NOT NULL,
|
exp TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||||
used TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
|
||||||
username VARCHAR(100) NOT NULL,
|
username VARCHAR(100) NOT NULL,
|
||||||
action VARCHAR(50) NOT NULL,
|
action VARCHAR(50) NOT NULL,
|
||||||
|
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||||
|
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
PRIMARY KEY (id),
|
PRIMARY KEY (id),
|
||||||
UNIQUE (jti)
|
UNIQUE (jti)
|
||||||
);
|
);
|
||||||
|
|
|
@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
|
||||||
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
username VARCHAR(100) NOT NULL,
|
username VARCHAR(100) NOT NULL,
|
||||||
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
||||||
remote_ip VARCHAR(47) NULL DEFAULT NULL,
|
remote_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
request_uri TEXT,
|
request_uri TEXT,
|
||||||
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
||||||
PRIMARY KEY (id)
|
PRIMARY KEY (id)
|
||||||
|
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
|
||||||
id INTEGER,
|
id INTEGER,
|
||||||
jti VARCHAR(36),
|
jti VARCHAR(36),
|
||||||
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
issued_ip VARCHAR(39) NOT NULL,
|
||||||
exp TIMESTAMP NOT NULL,
|
exp TIMESTAMP NOT NULL,
|
||||||
used TIMESTAMP NULL DEFAULT NULL,
|
|
||||||
username VARCHAR(100) NOT NULL,
|
username VARCHAR(100) NOT NULL,
|
||||||
action VARCHAR(50) NOT NULL,
|
action VARCHAR(50) NOT NULL,
|
||||||
|
consumed TIMESTAMP NULL DEFAULT NULL,
|
||||||
|
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
PRIMARY KEY (id),
|
PRIMARY KEY (id),
|
||||||
UNIQUE (jti)
|
UNIQUE (jti)
|
||||||
);
|
);
|
||||||
|
|
|
@ -18,17 +18,17 @@ type Provider interface {
|
||||||
LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error)
|
LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error)
|
||||||
|
|
||||||
SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error)
|
SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error)
|
||||||
RemoveIdentityVerification(ctx context.Context, jti string) (err error)
|
ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error)
|
||||||
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
||||||
|
|
||||||
SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error)
|
SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error)
|
||||||
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
||||||
LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error)
|
LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error)
|
||||||
LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error)
|
LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error)
|
||||||
UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error)
|
|
||||||
|
|
||||||
SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error)
|
SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error)
|
||||||
LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error)
|
LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error)
|
||||||
|
LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error)
|
||||||
|
|
||||||
SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error)
|
SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error)
|
||||||
DeletePreferredDuoDevice(ctx context.Context, username string) (err error)
|
DeletePreferredDuoDevice(ctx context.Context, username string) (err error)
|
||||||
|
|
|
@ -34,7 +34,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
||||||
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
|
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
|
||||||
|
|
||||||
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
|
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
|
||||||
sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
|
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
|
||||||
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
|
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
|
||||||
|
|
||||||
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
||||||
|
@ -45,8 +45,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
||||||
sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations),
|
sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations),
|
||||||
sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations),
|
sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations),
|
||||||
|
|
||||||
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
|
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
|
||||||
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
|
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
|
||||||
|
sqlSelectU2FDevices: fmt.Sprintf(queryFmtSelectU2FDevices, tableU2FDevices),
|
||||||
|
|
||||||
|
sqlUpdateU2FDevicePublicKey: fmt.Sprintf(queryFmtUpdateU2FDevicePublicKey, tableU2FDevices),
|
||||||
|
sqlUpdateU2FDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateU2FDevicePublicKeyByUsername, tableU2FDevices),
|
||||||
|
|
||||||
sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices),
|
sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices),
|
||||||
sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices),
|
sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices),
|
||||||
|
@ -86,7 +90,7 @@ type SQLProvider struct {
|
||||||
|
|
||||||
// Table: identity_verification.
|
// Table: identity_verification.
|
||||||
sqlInsertIdentityVerification string
|
sqlInsertIdentityVerification string
|
||||||
sqlDeleteIdentityVerification string
|
sqlConsumeIdentityVerification string
|
||||||
sqlSelectExistsIdentityVerification string
|
sqlSelectExistsIdentityVerification string
|
||||||
|
|
||||||
// Table: totp_configurations.
|
// Table: totp_configurations.
|
||||||
|
@ -99,8 +103,12 @@ type SQLProvider struct {
|
||||||
sqlUpdateTOTPConfigSecretByUsername string
|
sqlUpdateTOTPConfigSecretByUsername string
|
||||||
|
|
||||||
// Table: u2f_devices.
|
// Table: u2f_devices.
|
||||||
sqlUpsertU2FDevice string
|
sqlUpsertU2FDevice string
|
||||||
sqlSelectU2FDevice string
|
sqlSelectU2FDevice string
|
||||||
|
sqlSelectU2FDevices string
|
||||||
|
|
||||||
|
sqlUpdateU2FDevicePublicKey string
|
||||||
|
sqlUpdateU2FDevicePublicKeyByUsername string
|
||||||
|
|
||||||
// Table: duo_devices
|
// Table: duo_devices
|
||||||
sqlUpsertDuoDevice string
|
sqlUpsertDuoDevice string
|
||||||
|
@ -217,7 +225,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m
|
||||||
// SaveIdentityVerification save an identity verification record to the database.
|
// SaveIdentityVerification save an identity verification record to the database.
|
||||||
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
|
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
|
||||||
if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
|
if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
|
||||||
verification.JTI, verification.IssuedAt, verification.ExpiresAt,
|
verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt,
|
||||||
verification.Username, verification.Action); err != nil {
|
verification.Username, verification.Action); err != nil {
|
||||||
return fmt.Errorf("error inserting identity verification: %w", err)
|
return fmt.Errorf("error inserting identity verification: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -225,9 +233,9 @@ func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveIdentityVerification remove an identity verification record from the database.
|
// ConsumeIdentityVerification marks an identity verification record in the database as consumed.
|
||||||
func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, jti string) (err error) {
|
func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error) {
|
||||||
if _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, jti); err != nil {
|
if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil {
|
||||||
return fmt.Errorf("error updating identity verification: %w", err)
|
return fmt.Errorf("error updating identity verification: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -321,8 +329,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
|
||||||
return configs, nil
|
return configs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTOTPConfigurationSecret updates a TOTP configuration secret.
|
func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
|
||||||
func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
|
|
||||||
switch config.ID {
|
switch config.ID {
|
||||||
case 0:
|
case 0:
|
||||||
_, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
|
||||||
|
@ -339,6 +346,10 @@ func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config
|
||||||
|
|
||||||
// SaveU2FDevice saves a registered U2F device.
|
// SaveU2FDevice saves a registered U2F device.
|
||||||
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
|
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
|
||||||
|
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
|
||||||
|
return fmt.Errorf("error encrypting the U2F device public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.Description, device.KeyHandle, device.PublicKey); err != nil {
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.Description, device.KeyHandle, device.PublicKey); err != nil {
|
||||||
return fmt.Errorf("error upserting U2F device: %v", err)
|
return fmt.Errorf("error upserting U2F device: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -348,9 +359,7 @@ func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice
|
||||||
|
|
||||||
// LoadU2FDevice loads a U2F device registration for a given username.
|
// LoadU2FDevice loads a U2F device registration for a given username.
|
||||||
func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
|
func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
|
||||||
device = &models.U2FDevice{
|
device = &models.U2FDevice{}
|
||||||
Username: username,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil {
|
if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
@ -360,9 +369,64 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic
|
||||||
return nil, fmt.Errorf("error selecting U2F device: %w", err)
|
return nil, fmt.Errorf("error selecting U2F device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return device, nil
|
return device, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadU2FDevices loads U2F device registrations.
|
||||||
|
func (p *SQLProvider) LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error) {
|
||||||
|
rows, err := p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, limit, limit*page)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return devices, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("error selecting U2F devices: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
p.log.Errorf(logFmtErrClosingConn, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
devices = make([]models.U2FDevice, 0, limit)
|
||||||
|
|
||||||
|
var device models.U2FDevice
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.StructScan(&device); err != nil {
|
||||||
|
return nil, fmt.Errorf("error scanning U2F device to struct: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
devices = append(devices, device)
|
||||||
|
}
|
||||||
|
|
||||||
|
return devices, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) updateU2FDevicePublicKey(ctx context.Context, device models.U2FDevice) (err error) {
|
||||||
|
switch device.ID {
|
||||||
|
case 0:
|
||||||
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKeyByUsername, device.PublicKey, device.Username)
|
||||||
|
default:
|
||||||
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKey, device.PublicKey, device.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error updating U2F public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SavePreferredDuoDevice saves a Duo device.
|
// SavePreferredDuoDevice saves a Duo device.
|
||||||
func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) {
|
func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) {
|
||||||
_, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method)
|
_, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method)
|
||||||
|
|
|
@ -38,13 +38,16 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
|
||||||
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
|
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
|
||||||
provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification)
|
provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification)
|
||||||
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
||||||
provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification)
|
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
|
||||||
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
||||||
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
|
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
|
||||||
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
|
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
|
||||||
provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
|
provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
|
||||||
provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
|
provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
|
||||||
provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice)
|
provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice)
|
||||||
|
provider.sqlSelectU2FDevices = provider.db.Rebind(provider.sqlSelectU2FDevices)
|
||||||
|
provider.sqlUpdateU2FDevicePublicKey = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKey)
|
||||||
|
provider.sqlUpdateU2FDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKeyByUsername)
|
||||||
provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice)
|
provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice)
|
||||||
provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice)
|
provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice)
|
||||||
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
|
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
|
||||||
|
|
|
@ -22,6 +22,26 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
|
||||||
|
|
||||||
key := sha256.Sum256([]byte(encryptionKey))
|
key := sha256.Sum256([]byte(encryptionKey))
|
||||||
|
|
||||||
|
if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = p.schemaEncryptionChangeKeyU2F(ctx, tx, key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("rollback due to error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||||
var configs []models.TOTPConfiguration
|
var configs []models.TOTPConfiguration
|
||||||
|
|
||||||
for page := 0; true; page++ {
|
for page := 0; true; page++ {
|
||||||
|
@ -42,7 +62,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
|
||||||
return fmt.Errorf("rollback due to error: %w", err)
|
return fmt.Errorf("rollback due to error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.UpdateTOTPConfigurationSecret(ctx, config); err != nil {
|
if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil {
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||||
}
|
}
|
||||||
|
@ -56,15 +76,45 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
|
return nil
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
}
|
||||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
||||||
|
func (p *SQLProvider) schemaEncryptionChangeKeyU2F(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||||
|
var devices []models.U2FDevice
|
||||||
|
|
||||||
|
for page := 0; true; page++ {
|
||||||
|
if devices, err = p.LoadU2FDevices(ctx, 10, page); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("rollback due to error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("rollback due to error: %w", err)
|
for _, device := range devices {
|
||||||
|
if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("rollback due to error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = p.updateU2FDevicePublicKey(ctx, device); err != nil {
|
||||||
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||||
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("rollback due to error: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(devices) != 10 {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
|
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
|
||||||
|
@ -85,49 +135,12 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
||||||
}
|
}
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
var (
|
if err = p.schemaEncryptionCheckTOTP(ctx); err != nil {
|
||||||
config models.TOTPConfiguration
|
errs = append(errs, err)
|
||||||
row int
|
|
||||||
invalid int
|
|
||||||
total int
|
|
||||||
)
|
|
||||||
|
|
||||||
pageSize := 10
|
|
||||||
|
|
||||||
var rows *sqlx.Rows
|
|
||||||
|
|
||||||
for page := 0; true; page++ {
|
|
||||||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
|
|
||||||
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
row = 0
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
total++
|
|
||||||
row++
|
|
||||||
|
|
||||||
if err = rows.StructScan(&config); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = p.decrypt(config.Secret); err != nil {
|
|
||||||
invalid++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = rows.Close()
|
|
||||||
|
|
||||||
if row < pageSize {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if invalid != 0 {
|
if err = p.schemaEncryptionCheckU2F(ctx); err != nil {
|
||||||
errs = append(errs, fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total))
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,6 +161,104 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
|
||||||
|
var (
|
||||||
|
config models.TOTPConfiguration
|
||||||
|
row int
|
||||||
|
invalid int
|
||||||
|
total int
|
||||||
|
)
|
||||||
|
|
||||||
|
pageSize := 10
|
||||||
|
|
||||||
|
var rows *sqlx.Rows
|
||||||
|
|
||||||
|
for page := 0; true; page++ {
|
||||||
|
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row = 0
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
total++
|
||||||
|
row++
|
||||||
|
|
||||||
|
if err = rows.StructScan(&config); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = p.decrypt(config.Secret); err != nil {
|
||||||
|
invalid++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
if row < pageSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if invalid != 0 {
|
||||||
|
return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
|
||||||
|
var (
|
||||||
|
device models.U2FDevice
|
||||||
|
row int
|
||||||
|
invalid int
|
||||||
|
total int
|
||||||
|
)
|
||||||
|
|
||||||
|
pageSize := 10
|
||||||
|
|
||||||
|
var rows *sqlx.Rows
|
||||||
|
|
||||||
|
for page := 0; true; page++ {
|
||||||
|
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, pageSize, pageSize*page); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
return fmt.Errorf("error selecting U2F devices: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row = 0
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
total++
|
||||||
|
row++
|
||||||
|
|
||||||
|
if err = rows.StructScan(&device); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return fmt.Errorf("error scanning U2F device to struct: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = p.decrypt(device.PublicKey); err != nil {
|
||||||
|
invalid++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
if row < pageSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if invalid != 0 {
|
||||||
|
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||||
return utils.Encrypt(clearText, &p.key)
|
return utils.Encrypt(clearText, &p.key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,16 +60,16 @@ const (
|
||||||
SELECT EXISTS (
|
SELECT EXISTS (
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM %s
|
FROM %s
|
||||||
WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND used IS NULL
|
WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND consumed IS NULL
|
||||||
);`
|
);`
|
||||||
|
|
||||||
queryFmtInsertIdentityVerification = `
|
queryFmtInsertIdentityVerification = `
|
||||||
INSERT INTO %s (jti, iat, exp, username, action)
|
INSERT INTO %s (jti, iat, issued_ip, exp, username, action)
|
||||||
VALUES (?, ?, ?, ?, ?);`
|
VALUES (?, ?, ?, ?, ?, ?);`
|
||||||
|
|
||||||
queryFmtDeleteIdentityVerification = `
|
queryFmtConsumeIdentityVerification = `
|
||||||
UPDATE %s
|
UPDATE %s
|
||||||
SET used = CURRENT_TIMESTAMP
|
SET consumed = CURRENT_TIMESTAMP, consumed_ip = ?
|
||||||
WHERE jti = ?;`
|
WHERE jti = ?;`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -114,10 +114,26 @@ const (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
queryFmtSelectU2FDevice = `
|
queryFmtSelectU2FDevice = `
|
||||||
SELECT key_handle, public_key
|
SELECT id, username, key_handle, public_key
|
||||||
FROM %s
|
FROM %s
|
||||||
WHERE username = ?;`
|
WHERE username = ?;`
|
||||||
|
|
||||||
|
queryFmtSelectU2FDevices = `
|
||||||
|
SELECT id, username, key_handle, public_key
|
||||||
|
FROM %s
|
||||||
|
LIMIT ?
|
||||||
|
OFFSET ?;`
|
||||||
|
|
||||||
|
queryFmtUpdateU2FDevicePublicKey = `
|
||||||
|
UPDATE %s
|
||||||
|
SET public_key = ?
|
||||||
|
WHERE id = ?;`
|
||||||
|
|
||||||
|
queryFmtUpdateUpdateU2FDevicePublicKeyByUsername = `
|
||||||
|
UPDATE %s
|
||||||
|
SET public_key = ?
|
||||||
|
WHERE username = ?;`
|
||||||
|
|
||||||
queryFmtUpsertU2FDevice = `
|
queryFmtUpsertU2FDevice = `
|
||||||
REPLACE INTO %s (username, description, key_handle, public_key)
|
REPLACE INTO %s (username, description, key_handle, public_key)
|
||||||
VALUES (?, ?, ?, ?);`
|
VALUES (?, ?, ?, ?);`
|
||||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -57,9 +58,13 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error
|
||||||
return migration.After, nil
|
return migration.After, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) &&
|
var tablesV1 = []string{tableDuoDevices, tableEncryption, tableIdentityVerification, tableMigrations, tableTOTPConfigurations}
|
||||||
utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) &&
|
|
||||||
utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) {
|
if utils.IsStringSliceContainsAll(tablesPre1, tables) {
|
||||||
|
if utils.IsStringSliceContainsAny(tablesV1, tables) {
|
||||||
|
return -2, errors.New("pre1 schema contains v1 tables it shouldn't contain")
|
||||||
|
}
|
||||||
|
|
||||||
return -1, nil
|
return -1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -267,7 +267,12 @@ func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey})
|
encryptedPublicKey, err := p.encrypt(publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: encryptedPublicKey})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, device := range devices {
|
for _, device := range devices {
|
||||||
|
@ -446,6 +451,11 @@ func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device.PublicKey, err = p.decrypt(device.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
devices = append(devices, device)
|
devices = append(devices, device)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,6 +91,17 @@ func IsStringSliceContainsAll(needles []string, haystack []string) (inSlice bool
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles.
|
||||||
|
func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) {
|
||||||
|
for _, n := range needles {
|
||||||
|
if IsStringInSlice(n, haystack) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// SliceString splits a string s into an array with each item being a max of int d
|
// SliceString splits a string s into an array with each item being a max of int d
|
||||||
// d = denominator, n = numerator, q = quotient, r = remainder.
|
// d = denominator, n = numerator, q = quotient, r = remainder.
|
||||||
func SliceString(s string, d int) (array []string) {
|
func SliceString(s string, d int) (array []string) {
|
||||||
|
|
|
@ -162,3 +162,12 @@ func TestIsStringSliceContainsAll(t *testing.T) {
|
||||||
assert.True(t, IsStringSliceContainsAll(needles, haystackOne))
|
assert.True(t, IsStringSliceContainsAll(needles, haystackOne))
|
||||||
assert.False(t, IsStringSliceContainsAll(needles, haystackTwo))
|
assert.False(t, IsStringSliceContainsAll(needles, haystackTwo))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsStringSliceContainsAny(t *testing.T) {
|
||||||
|
needles := []string{"abc", "123", "xyz"}
|
||||||
|
haystackOne := []string{"tvu", "456", "hij"}
|
||||||
|
haystackTwo := []string{"tvu", "123", "456", "xyz"}
|
||||||
|
|
||||||
|
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
|
||||||
|
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue