From 255aaeb2ad4ab8bad95fac93707066ac01e11e0d Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 3 Dec 2021 11:04:11 +1100 Subject: [PATCH] feat(storage): encrypt u2f key (#2664) Adds encryption to the U2F public keys. While the public keys cannot be used to authenticate, only to validate someone is authenticated, if a rogue operator changed these in the database they may be able to bypass 2FA. This prevents that. --- docs/configuration/storage/index.md | 10 +- docs/security/measures.md | 48 ++++ docs/security/threat-model.md | 1 - internal/handlers/handler_firstfactor_test.go | 6 +- .../handler_register_u2f_step1_test.go | 14 +- internal/handlers/handler_sign_duo_test.go | 16 +- internal/handlers/handler_sign_totp_test.go | 10 +- .../handlers/handler_sign_u2f_step2_test.go | 10 +- internal/middlewares/identity_verification.go | 4 +- .../middlewares/identity_verification_test.go | 22 +- internal/mocks/storage.go | 57 ++--- internal/models/authentication_attempt.go | 2 +- internal/models/identity_verification.go | 20 +- internal/models/types.go | 74 +++++-- internal/regulation/regulator.go | 2 +- internal/storage/const.go | 17 +- .../V0001.Initial_Schema.mysql.up.sql | 6 +- .../V0001.Initial_Schema.postgres.up.sql | 6 +- .../V0001.Initial_Schema.sqlite.up.sql | 6 +- internal/storage/provider.go | 4 +- internal/storage/sql_provider.go | 94 ++++++-- .../storage/sql_provider_backend_postgres.go | 5 +- internal/storage/sql_provider_encryption.go | 205 ++++++++++++++---- internal/storage/sql_provider_queries.go | 28 ++- internal/storage/sql_provider_schema.go | 11 +- internal/storage/sql_provider_schema_pre1.go | 12 +- internal/utils/strings.go | 11 + internal/utils/strings_test.go | 9 + 28 files changed, 528 insertions(+), 182 deletions(-) diff --git a/docs/configuration/storage/index.md b/docs/configuration/storage/index.md index 7a1881289..4b6caae8e 100644 --- a/docs/configuration/storage/index.md +++ b/docs/configuration/storage/index.md @@ -31,12 +31,12 @@ required: yes {: .label .label-config .label-red } -The encryption key used to encrypt data in the database. It has a minimum length of 20 and must be provided. We encrypt -data by creating a sha256 checksum of the provided value, and use that to encrypt the data with the AES-GCM 256bit -algorithm. +The encryption key used to encrypt data in the database. We encrypt data by creating a sha256 checksum of the provided +value, and use that to encrypt the data with the AES-GCM 256bit algorithm. -The encrypted data in the database is as follows: -- TOTP Secret +The minimum length of this key is 20 characters, however we generally recommend above 64 characters. + +See [securty measures](../../security/measures.md#storage-security-measures) for more information. ### local See [SQLite](./sqlite.md). diff --git a/docs/security/measures.md b/docs/security/measures.md index ba4bcc867..aca32052a 100644 --- a/docs/security/measures.md +++ b/docs/security/measures.md @@ -81,6 +81,54 @@ LDAP implementations vary, so please ask if you need some assistance in configur These protections can be [tuned](../configuration/authentication/ldap.md#refresh-interval) according to your security policy by changing refresh_interval, however we believe that 5 minutes is a fairly safe interval. +## Storage security measures + +We force users to encrypt vulnerable data stored in the database. It is strongly advised you do not give this encryption +key to anyone. In the instance of a database installation that multiple users have access to, you should aim to ensure +that users who have access to the database do not also have access to this key. + +The encrypted data in the database is as follows: + +|Table |Column |Rational | +|:-----------------:|:--------:|:----------------------------------------------------------------------------------------------------:| +|totp_configurations|secret |Prevents a [Leaked Database](#leaked-database) or [Bad Actors](#bad-actors) from compromising security| +|u2f_devices |public_key|Prevents [Bad Actors](#bad-actors) from compromising security | + +### Leaked Database + +A leaked database can reasonably compromise security if there are credentials that are not encrypted. Columns encrypted +for this purpose prevent this attack vector. + +### Bad Actors + +A bad actor who has the SQL password and access to the database can theoretically change another users credential, this +theoretically bypasses authentication. Columns encrypted for this purpose prevent this attack vector. + +A bad actor may also be able to use data in the database to bypass 2FA silently depending on the credentials. In the +instance of the U2F public key this is not possible, they can only change it which would eventually alert the user in +question. But in the case of TOTP they can use the secret to authenticate without knowledge of the user in question. + +### Encryption key management + +You must supply the encryption key in the recommended method of a [secret](../configuration/secrets.md) or in one of +the other methods available for [configuration](../configuration/index.md#configuration). + +If you wish to change your encryption key for any reason you can do so using the following steps: + +1. Run the `authelia --version` command to determine the version of Authelia you're running and either download that + version or run another container of that version interactively. All the subsequent commands assume you're running + the `authelia` binary in the current working directory. You will have to adjust this according to how you're running + it. +2. Run the `./authelia storage encryption change-key --help` command. +3. Stop Authelia. + - You can skip this step, however note that any data changed between the time you make the change and the time when + you stop Authelia i.e. via user registering a device; will be encrypted with the incorrect key. +4. Run the `./authelia storage encryption change-key` command with the appropriate parameters. + - The help from step 1 will be useful here. The easiest method to accomplish this is with the `--config`, + `--encryption-key`, and `--new-encryption-key` parameters. +5. Update the encryption key Authelia uses on startup. +6. Start Authelia. + ## Notifier security measures (SMTP) The SMTP Notifier implementation does not allow connections that are not secure without changing default configuration diff --git a/docs/security/threat-model.md b/docs/security/threat-model.md index b0ac55786..0f549e301 100644 --- a/docs/security/threat-model.md +++ b/docs/security/threat-model.md @@ -52,7 +52,6 @@ If properly configured, Authelia guarantees the following for security of your u * Binding session cookies to single IP addresses. * Authenticate communication between Authelia and reverse proxy. * Securely transmit authentication data to backends (OAuth2 with bearer tokens). -* Protect secrets stored in the database with encryption to prevent secrets leak by database exfiltration. * Least privilege on LDAP binding operations (currently administrative user is used to bind while it could be anonymous for most operations). * Extend the check of user group memberships to authentication backends other than LDAP (File currently). diff --git a/internal/handlers/handler_firstfactor_test.go b/internal/handlers/handler_firstfactor_test.go index bf332ba34..fe26016a7 100644 --- a/internal/handlers/handler_firstfactor_test.go +++ b/internal/handlers/handler_firstfactor_test.go @@ -65,7 +65,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthType1FA, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.Ctx.Request.SetBodyString(`{ @@ -93,7 +93,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthType1FA, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.Ctx.Request.SetBodyString(`{ @@ -119,7 +119,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthType1FA, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.Ctx.Request.SetBodyString(`{ diff --git a/internal/handlers/handler_register_u2f_step1_test.go b/internal/handlers/handler_register_u2f_step1_test.go index 5e6707e1e..42a566bbe 100644 --- a/internal/handlers/handler_register_u2f_step1_test.go +++ b/internal/handlers/handler_register_u2f_step1_test.go @@ -34,21 +34,21 @@ func (s *HandlerRegisterU2FStep1Suite) TearDownTest() { s.mock.Close() } -func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) { - verification = models.NewIdentityVerification(username, action) +func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) { + verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP()) verification.ExpiresAt = expiresAt claims := verification.ToIdentityVerificationClaim() token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - ss, _ := token.SignedString([]byte(secret)) + ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret)) return ss, verification } func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() { - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration, + token, verification := createToken(s.mock, "john", ActionU2FRegistration, time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -57,7 +57,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi Return(true, nil) s.mock.StorageMock.EXPECT(). - RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). + ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(nil) SecondFactorU2FIdentityFinish(s.mock.Ctx) @@ -68,7 +68,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() { s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration, + token, verification := createToken(s.mock, "john", ActionU2FRegistration, time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -77,7 +77,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin Return(true, nil) s.mock.StorageMock.EXPECT(). - RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). + ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(nil) SecondFactorU2FIdentityFinish(s.mock.Ctx) diff --git a/internal/handlers/handler_sign_duo_test.go b/internal/handlers/handler_sign_duo_test.go index 71bfdd62a..e8959a5ca 100644 --- a/internal/handlers/handler_sign_duo_test.go +++ b/internal/handlers/handler_sign_duo_test.go @@ -97,7 +97,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -286,7 +286,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -414,7 +414,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndDenyAccess() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -497,7 +497,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToDefaultURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -546,7 +546,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotReturnRedirectURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -591,7 +591,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToSafeTargetURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -640,7 +640,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotRedirectToUnsafeURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) @@ -687,7 +687,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRegenerateSessionForPreventingSessi Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeDuo, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })). Return(nil) diff --git a/internal/handlers/handler_sign_totp_test.go b/internal/handlers/handler_sign_totp_test.go index b0163a0e1..84d3a3b2b 100644 --- a/internal/handlers/handler_sign_totp_test.go +++ b/internal/handlers/handler_sign_totp_test.go @@ -51,7 +51,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeTOTP, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil) @@ -85,7 +85,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeTOTP, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil) @@ -115,7 +115,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeTOTP, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil) @@ -146,7 +146,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeTOTP, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.TOTPMock.EXPECT(). @@ -180,7 +180,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFi Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeTOTP, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.TOTPMock.EXPECT(). diff --git a/internal/handlers/handler_sign_u2f_step2_test.go b/internal/handlers/handler_sign_u2f_step2_test.go index 7fb0ec180..0e050f36d 100644 --- a/internal/handlers/handler_sign_u2f_step2_test.go +++ b/internal/handlers/handler_sign_u2f_step2_test.go @@ -51,7 +51,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToDefaultURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeU2F, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL @@ -83,7 +83,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotReturnRedirectURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeU2F, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) bodyBytes, err := json.Marshal(signU2FRequestBody{ @@ -111,7 +111,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToSafeTargetURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeU2F, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) bodyBytes, err := json.Marshal(signU2FRequestBody{ @@ -142,7 +142,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotRedirectToUnsafeURL() { Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeU2F, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) bodyBytes, err := json.Marshal(signU2FRequestBody{ @@ -171,7 +171,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRegenerateSessionForPreventingSessi Banned: false, Time: s.mock.Clock.Now(), Type: regulation.AuthTypeU2F, - RemoteIP: models.NewIPAddressFromString("0.0.0.0"), + RemoteIP: models.NewNullIPFromString("0.0.0.0"), })) bodyBytes, err := json.Marshal(signU2FRequestBody{ diff --git a/internal/middlewares/identity_verification.go b/internal/middlewares/identity_verification.go index 04a5bf240..96751cd92 100644 --- a/internal/middlewares/identity_verification.go +++ b/internal/middlewares/identity_verification.go @@ -27,7 +27,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle return } - verification := models.NewIdentityVerification(identity.Username, args.ActionClaim) + verification := models.NewIdentityVerification(identity.Username, args.ActionClaim, ctx.RemoteIP()) // Create the claim with the action to sign it. claims := verification.ToIdentityVerificationClaim() @@ -183,7 +183,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c return } - err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, claims.ID) + err = ctx.Providers.StorageProvider.ConsumeIdentityVerification(ctx, claims.ID, models.NewNullIP(ctx.RemoteIP())) if err != nil { ctx.Error(err, messageOperationFailed) return diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go index a557aeee8..905036963 100644 --- a/internal/middlewares/identity_verification_test.go +++ b/internal/middlewares/identity_verification_test.go @@ -165,15 +165,15 @@ func (s *IdentityVerificationFinishProcess) TearDownTest() { s.mock.Close() } -func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) { - verification = models.NewIdentityVerification(username, action) +func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) { + verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP()) verification.ExpiresAt = expiresAt claims := verification.ToIdentityVerificationClaim() token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - ss, _ := token.SignedString([]byte(secret)) + ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret)) return ss, verification } @@ -203,7 +203,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotProvided() } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() { - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "Login", + token, verification := createToken(s.mock, "john", "Login", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -229,7 +229,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() { func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() { args := newArgs(defaultRetriever) - token, _ := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", args.ActionClaim, + token, _ := createToken(s.mock, "john", args.ActionClaim, time.Now().Add(-1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -240,7 +240,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() { } func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() { - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "", "", + token, verification := createToken(s.mock, "", "", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -255,7 +255,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() { } func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() { - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "harry", "EXP_ACTION", + token, verification := createToken(s.mock, "harry", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -272,7 +272,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() { } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() { - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION", + token, verification := createToken(s.mock, "john", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -281,7 +281,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved Return(true, nil) s.mock.StorageMock.EXPECT(). - RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). + ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(fmt.Errorf("cannot remove")) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) @@ -291,7 +291,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved } func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() { - token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION", + token, verification := createToken(s.mock, "john", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) @@ -300,7 +300,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete( Return(true, nil) s.mock.StorageMock.EXPECT(). - RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). + ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) diff --git a/internal/mocks/storage.go b/internal/mocks/storage.go index 6212c17f9..0e3014901 100644 --- a/internal/mocks/storage.go +++ b/internal/mocks/storage.go @@ -65,6 +65,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close)) } +// ConsumeIdentityVerification mocks base method. +func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 models.NullIP) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConsumeIdentityVerification", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ConsumeIdentityVerification indicates an expected call of ConsumeIdentityVerification. +func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2) +} + // DeletePreferredDuoDevice mocks base method. func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() @@ -198,6 +212,21 @@ func (mr *MockStorageMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevice), arg0, arg1) } +// LoadU2FDevices mocks base method. +func (m *MockStorage) LoadU2FDevices(arg0 context.Context, arg1, arg2 int) ([]models.U2FDevice, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadU2FDevices", arg0, arg1, arg2) + ret0, _ := ret[0].([]models.U2FDevice) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadU2FDevices indicates an expected call of LoadU2FDevices. +func (mr *MockStorageMockRecorder) LoadU2FDevices(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevices", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevices), arg0, arg1, arg2) +} + // LoadUserInfo mocks base method. func (m *MockStorage) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) { m.ctrl.T.Helper() @@ -213,20 +242,6 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1) } -// RemoveIdentityVerification mocks base method. -func (m *MockStorage) RemoveIdentityVerification(arg0 context.Context, arg1 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification. -func (mr *MockStorageMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).RemoveIdentityVerification), arg0, arg1) -} - // SaveIdentityVerification mocks base method. func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error { m.ctrl.T.Helper() @@ -442,17 +457,3 @@ func (mr *MockStorageMockRecorder) StartupCheck() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockStorage)(nil).StartupCheck)) } - -// UpdateTOTPConfigurationSecret mocks base method. -func (m *MockStorage) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 models.TOTPConfiguration) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTOTPConfigurationSecret", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateTOTPConfigurationSecret indicates an expected call of UpdateTOTPConfigurationSecret. -func (mr *MockStorageMockRecorder) UpdateTOTPConfigurationSecret(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTOTPConfigurationSecret", reflect.TypeOf((*MockStorage)(nil).UpdateTOTPConfigurationSecret), arg0, arg1) -} diff --git a/internal/models/authentication_attempt.go b/internal/models/authentication_attempt.go index 01812812c..36ad1f31f 100644 --- a/internal/models/authentication_attempt.go +++ b/internal/models/authentication_attempt.go @@ -12,7 +12,7 @@ type AuthenticationAttempt struct { Banned bool `db:"banned"` Username string `db:"username"` Type string `db:"auth_type"` - RemoteIP IPAddress `db:"remote_ip"` + RemoteIP NullIP `db:"remote_ip"` RequestURI string `db:"request_uri"` RequestMethod string `db:"request_method"` } diff --git a/internal/models/identity_verification.go b/internal/models/identity_verification.go index e5e180c1e..7eaf1edb3 100644 --- a/internal/models/identity_verification.go +++ b/internal/models/identity_verification.go @@ -1,6 +1,7 @@ package models import ( + "net" "time" "github.com/golang-jwt/jwt/v4" @@ -8,25 +9,28 @@ import ( ) // NewIdentityVerification creates a new IdentityVerification from a given username and action. -func NewIdentityVerification(username, action string) (verification IdentityVerification) { +func NewIdentityVerification(username, action string, ip net.IP) (verification IdentityVerification) { return IdentityVerification{ JTI: uuid.New(), IssuedAt: time.Now(), ExpiresAt: time.Now().Add(5 * time.Minute), Action: action, Username: username, + IssuedIP: NewIP(ip), } } // IdentityVerification represents an identity verification row in the database. type IdentityVerification struct { - ID int `db:"id"` - JTI uuid.UUID `db:"jti"` - IssuedAt time.Time `db:"iat"` - ExpiresAt time.Time `db:"exp"` - Used *time.Time `db:"used"` - Action string `db:"action"` - Username string `db:"username"` + ID int `db:"id"` + JTI uuid.UUID `db:"jti"` + IssuedAt time.Time `db:"iat"` + IssuedIP IP `db:"issued_ip"` + ExpiresAt time.Time `db:"exp"` + Action string `db:"action"` + Username string `db:"username"` + Consumed *time.Time `db:"consumed"` + ConsumedIP NullIP `db:"consumed_ip"` } // ToIdentityVerificationClaim converts the IdentityVerification into a IdentityVerificationClaim. diff --git a/internal/models/types.go b/internal/models/types.go index f1cba4485..a0c6e7e82 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -2,23 +2,71 @@ package models import ( "database/sql/driver" + "errors" "fmt" "net" ) -// NewIPAddressFromString converts a string into an IPAddress. -func NewIPAddressFromString(ip string) (ipAddress IPAddress) { - actualIP := net.ParseIP(ip) - return IPAddress{IP: &actualIP} +// NewIP easily constructs a new IP. +func NewIP(value net.IP) (ip IP) { + return IP{IP: value} } -// IPAddress is a type specific for storage of a net.IP in the database. -type IPAddress struct { - *net.IP +// NewNullIP easily constructs a new NullIP. +func NewNullIP(value net.IP) (ip NullIP) { + return NullIP{IP: value} } -// Value is the IPAddress implementation of the databases/sql driver.Valuer. -func (ip IPAddress) Value() (value driver.Value, err error) { +// NewNullIPFromString easily constructs a new NullIP from a string. +func NewNullIPFromString(value string) (ip NullIP) { + if value == "" { + return ip + } + + return NullIP{IP: net.ParseIP(value)} +} + +// IP is a type specific for storage of a net.IP in the database which can't be NULL. +type IP struct { + IP net.IP +} + +// Value is the IP implementation of the databases/sql driver.Valuer. +func (ip IP) Value() (value driver.Value, err error) { + if ip.IP == nil { + return nil, errors.New("cannot value nil IP to driver.Value") + } + + return driver.Value(ip.IP.String()), nil +} + +// Scan is the IP implementation of the sql.Scanner. +func (ip *IP) Scan(src interface{}) (err error) { + if src == nil { + return errors.New("cannot scan nil to type IP") + } + + var value string + + switch v := src.(type) { + case string: + value = v + default: + return fmt.Errorf("invalid type %T for IP %v", src, src) + } + + ip.IP = net.ParseIP(value) + + return nil +} + +// NullIP is a type specific for storage of a net.IP in the database which can also be NULL. +type NullIP struct { + IP net.IP +} + +// Value is the NullIP implementation of the databases/sql driver.Valuer. +func (ip NullIP) Value() (value driver.Value, err error) { if ip.IP == nil { return driver.Value(nil), nil } @@ -26,8 +74,8 @@ func (ip IPAddress) Value() (value driver.Value, err error) { return driver.Value(ip.IP.String()), nil } -// Scan is the IPAddress implementation of the sql.Scanner. -func (ip *IPAddress) Scan(src interface{}) (err error) { +// Scan is the NullIP implementation of the sql.Scanner. +func (ip *NullIP) Scan(src interface{}) (err error) { if src == nil { ip.IP = nil return nil @@ -39,10 +87,10 @@ func (ip *IPAddress) Scan(src interface{}) (err error) { case string: value = v default: - return fmt.Errorf("invalid type %T for IPAddress %v", src, src) + return fmt.Errorf("invalid type %T for NullIP %v", src, src) } - *ip.IP = net.ParseIP(value) + ip.IP = net.ParseIP(value) return nil } diff --git a/internal/regulation/regulator.go b/internal/regulation/regulator.go index 5b59081ba..6104c1504 100644 --- a/internal/regulation/regulator.go +++ b/internal/regulation/regulator.go @@ -51,7 +51,7 @@ func (r *Regulator) Mark(ctx context.Context, successful, banned bool, username, Banned: banned, Username: username, Type: authType, - RemoteIP: models.IPAddress{IP: &remoteIP}, + RemoteIP: models.NewNullIP(remoteIP), RequestURI: requestURI, RequestMethod: requestMethod, }) diff --git a/internal/storage/const.go b/internal/storage/const.go index 22455c0df..efb4389c4 100644 --- a/internal/storage/const.go +++ b/internal/storage/const.go @@ -23,9 +23,11 @@ const ( // WARNING: Do not change/remove these consts. They are used for Pre1 migrations. const ( - tablePre1TOTPSecrets = "totp_secrets" - tablePre1Config = "config" - tablePre1IdentityVerificationTokens = "identity_verification_tokens" + tablePre1TOTPSecrets = "totp_secrets" + tablePre1IdentityVerificationTokens = "identity_verification_tokens" + + tablePre1Config = "config" + tableAlphaAuthenticationLogs = "AuthenticationLogs" tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens" tableAlphaPreferences = "Preferences" @@ -35,6 +37,15 @@ const ( tableAlphaU2FDeviceHandles = "U2FDeviceHandles" ) +var tablesPre1 = []string{ + tablePre1TOTPSecrets, + tablePre1IdentityVerificationTokens, + + tableUserPreferences, + tableU2FDevices, + tableAuthenticationLogs, +} + const ( providerAll = "all" providerMySQL = "mysql" diff --git a/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql b/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql index 1e6dda17a..0d2a4283e 100644 --- a/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql +++ b/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql @@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs ( banned BOOLEAN NOT NULL DEFAULT FALSE, username VARCHAR(100) NOT NULL, auth_type VARCHAR(8) NOT NULL DEFAULT '1FA', - remote_ip VARCHAR(47) NULL DEFAULT NULL, + remote_ip VARCHAR(39) NULL DEFAULT NULL, request_uri TEXT NOT NULL, request_method VARCHAR(8) NOT NULL DEFAULT '', PRIMARY KEY (id) @@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification ( id INTEGER AUTO_INCREMENT, jti CHAR(36), iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + issued_ip VARCHAR(39) NOT NULL, exp TIMESTAMP NOT NULL, - used TIMESTAMP NULL DEFAULT NULL, username VARCHAR(100) NOT NULL, action VARCHAR(50) NOT NULL, + consumed TIMESTAMP NULL DEFAULT NULL, + consumed_ip VARCHAR(39) NULL DEFAULT NULL, PRIMARY KEY (id), UNIQUE KEY (jti) ); diff --git a/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql b/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql index 30beeaa30..c56e70421 100644 --- a/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql +++ b/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql @@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs ( banned BOOLEAN NOT NULL DEFAULT FALSE, username VARCHAR(100) NOT NULL, auth_type VARCHAR(8) NOT NULL DEFAULT '1FA', - remote_ip VARCHAR(47) NULL DEFAULT NULL, + remote_ip VARCHAR(39) NULL DEFAULT NULL, request_uri TEXT, request_method VARCHAR(8) NOT NULL DEFAULT '', PRIMARY KEY (id) @@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification ( id SERIAL, jti CHAR(36), iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + issued_ip VARCHAR(39) NOT NULL, exp TIMESTAMP WITH TIME ZONE NOT NULL, - used TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, username VARCHAR(100) NOT NULL, action VARCHAR(50) NOT NULL, + consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, + consumed_ip VARCHAR(39) NULL DEFAULT NULL, PRIMARY KEY (id), UNIQUE (jti) ); diff --git a/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql b/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql index e2aa1b3e9..e4e49e4ba 100644 --- a/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql +++ b/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql @@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs ( banned BOOLEAN NOT NULL DEFAULT FALSE, username VARCHAR(100) NOT NULL, auth_type VARCHAR(8) NOT NULL DEFAULT '1FA', - remote_ip VARCHAR(47) NULL DEFAULT NULL, + remote_ip VARCHAR(39) NULL DEFAULT NULL, request_uri TEXT, request_method VARCHAR(8) NOT NULL DEFAULT '', PRIMARY KEY (id) @@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification ( id INTEGER, jti VARCHAR(36), iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + issued_ip VARCHAR(39) NOT NULL, exp TIMESTAMP NOT NULL, - used TIMESTAMP NULL DEFAULT NULL, username VARCHAR(100) NOT NULL, action VARCHAR(50) NOT NULL, + consumed TIMESTAMP NULL DEFAULT NULL, + consumed_ip VARCHAR(39) NULL DEFAULT NULL, PRIMARY KEY (id), UNIQUE (jti) ); diff --git a/internal/storage/provider.go b/internal/storage/provider.go index bd86ac954..b74ed8419 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -18,17 +18,17 @@ type Provider interface { LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) - RemoveIdentityVerification(ctx context.Context, jti string) (err error) + ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error) - UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) + LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 816f39f18..5bb89cec3 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -34,7 +34,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification), - sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification), + sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification), sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification), sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), @@ -45,8 +45,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations), sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations), - sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices), - sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices), + sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices), + sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices), + sqlSelectU2FDevices: fmt.Sprintf(queryFmtSelectU2FDevices, tableU2FDevices), + + sqlUpdateU2FDevicePublicKey: fmt.Sprintf(queryFmtUpdateU2FDevicePublicKey, tableU2FDevices), + sqlUpdateU2FDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateU2FDevicePublicKeyByUsername, tableU2FDevices), sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices), sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices), @@ -86,7 +90,7 @@ type SQLProvider struct { // Table: identity_verification. sqlInsertIdentityVerification string - sqlDeleteIdentityVerification string + sqlConsumeIdentityVerification string sqlSelectExistsIdentityVerification string // Table: totp_configurations. @@ -99,8 +103,12 @@ type SQLProvider struct { sqlUpdateTOTPConfigSecretByUsername string // Table: u2f_devices. - sqlUpsertU2FDevice string - sqlSelectU2FDevice string + sqlUpsertU2FDevice string + sqlSelectU2FDevice string + sqlSelectU2FDevices string + + sqlUpdateU2FDevicePublicKey string + sqlUpdateU2FDevicePublicKeyByUsername string // Table: duo_devices sqlUpsertDuoDevice string @@ -217,7 +225,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m // SaveIdentityVerification save an identity verification record to the database. func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, - verification.JTI, verification.IssuedAt, verification.ExpiresAt, + verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt, verification.Username, verification.Action); err != nil { return fmt.Errorf("error inserting identity verification: %w", err) } @@ -225,9 +233,9 @@ func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification return nil } -// RemoveIdentityVerification remove an identity verification record from the database. -func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, jti string) (err error) { - if _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, jti); err != nil { +// ConsumeIdentityVerification marks an identity verification record in the database as consumed. +func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil { return fmt.Errorf("error updating identity verification: %w", err) } @@ -321,8 +329,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in return configs, nil } -// UpdateTOTPConfigurationSecret updates a TOTP configuration secret. -func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) { +func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) { switch config.ID { case 0: _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username) @@ -339,6 +346,10 @@ func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config // SaveU2FDevice saves a registered U2F device. func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) { + if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil { + return fmt.Errorf("error encrypting the U2F device public key: %v", err) + } + if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.Description, device.KeyHandle, device.PublicKey); err != nil { return fmt.Errorf("error upserting U2F device: %v", err) } @@ -348,9 +359,7 @@ func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice // LoadU2FDevice loads a U2F device registration for a given username. func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) { - device = &models.U2FDevice{ - Username: username, - } + device = &models.U2FDevice{} if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -360,9 +369,64 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic return nil, fmt.Errorf("error selecting U2F device: %w", err) } + if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil { + return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err) + } + return device, nil } +// LoadU2FDevices loads U2F device registrations. +func (p *SQLProvider) LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error) { + rows, err := p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, limit, limit*page) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return devices, nil + } + + return nil, fmt.Errorf("error selecting U2F devices: %w", err) + } + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + devices = make([]models.U2FDevice, 0, limit) + + var device models.U2FDevice + + for rows.Next() { + if err = rows.StructScan(&device); err != nil { + return nil, fmt.Errorf("error scanning U2F device to struct: %w", err) + } + + if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil { + return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err) + } + + devices = append(devices, device) + } + + return devices, nil +} + +func (p *SQLProvider) updateU2FDevicePublicKey(ctx context.Context, device models.U2FDevice) (err error) { + switch device.ID { + case 0: + _, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKeyByUsername, device.PublicKey, device.Username) + default: + _, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKey, device.PublicKey, device.ID) + } + + if err != nil { + return fmt.Errorf("error updating U2F public key: %w", err) + } + + return nil +} + // SavePreferredDuoDevice saves a Duo device. func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) { _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method) diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 79dc4d6de..07868e0e6 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -38,13 +38,16 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo) provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification) provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification) - provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification) + provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification) provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig) provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig) provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs) provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret) provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername) provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice) + provider.sqlSelectU2FDevices = provider.db.Rebind(provider.sqlSelectU2FDevices) + provider.sqlUpdateU2FDevicePublicKey = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKey) + provider.sqlUpdateU2FDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKeyByUsername) provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice) provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice) provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt) diff --git a/internal/storage/sql_provider_encryption.go b/internal/storage/sql_provider_encryption.go index b7b2d1172..32aae959d 100644 --- a/internal/storage/sql_provider_encryption.go +++ b/internal/storage/sql_provider_encryption.go @@ -22,6 +22,26 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK key := sha256.Sum256([]byte(encryptionKey)) + if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil { + return err + } + + if err = p.schemaEncryptionChangeKeyU2F(ctx, tx, key); err != nil { + return err + } + + if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) + } + + return fmt.Errorf("rollback due to error: %w", err) + } + + return tx.Commit() +} + +func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) { var configs []models.TOTPConfiguration for page := 0; true; page++ { @@ -42,7 +62,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK return fmt.Errorf("rollback due to error: %w", err) } - if err = p.UpdateTOTPConfigurationSecret(ctx, config); err != nil { + if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) } @@ -56,15 +76,45 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK } } - if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) + return nil +} + +func (p *SQLProvider) schemaEncryptionChangeKeyU2F(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) { + var devices []models.U2FDevice + + for page := 0; true; page++ { + if devices, err = p.LoadU2FDevices(ctx, 10, page); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) + } + + return fmt.Errorf("rollback due to error: %w", err) } - return fmt.Errorf("rollback due to error: %w", err) + for _, device := range devices { + if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) + } + + return fmt.Errorf("rollback due to error: %w", err) + } + + if err = p.updateU2FDevicePublicKey(ctx, device); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err) + } + + return fmt.Errorf("rollback due to error: %w", err) + } + } + + if len(devices) != 10 { + break + } } - return tx.Commit() + return nil } // SchemaEncryptionCheckKey checks the encryption key configured is valid for the database. @@ -85,49 +135,12 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool } if verbose { - var ( - config models.TOTPConfiguration - row int - invalid int - total int - ) - - pageSize := 10 - - var rows *sqlx.Rows - - for page := 0; true; page++ { - if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil { - _ = rows.Close() - - return fmt.Errorf("error selecting TOTP configurations: %w", err) - } - - row = 0 - - for rows.Next() { - total++ - row++ - - if err = rows.StructScan(&config); err != nil { - _ = rows.Close() - return fmt.Errorf("error scanning TOTP configuration to struct: %w", err) - } - - if _, err = p.decrypt(config.Secret); err != nil { - invalid++ - } - } - - _ = rows.Close() - - if row < pageSize { - break - } + if err = p.schemaEncryptionCheckTOTP(ctx); err != nil { + errs = append(errs, err) } - if invalid != 0 { - errs = append(errs, fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)) + if err = p.schemaEncryptionCheckU2F(ctx); err != nil { + errs = append(errs, err) } } @@ -148,6 +161,104 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool return nil } +func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) { + var ( + config models.TOTPConfiguration + row int + invalid int + total int + ) + + pageSize := 10 + + var rows *sqlx.Rows + + for page := 0; true; page++ { + if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil { + _ = rows.Close() + + return fmt.Errorf("error selecting TOTP configurations: %w", err) + } + + row = 0 + + for rows.Next() { + total++ + row++ + + if err = rows.StructScan(&config); err != nil { + _ = rows.Close() + return fmt.Errorf("error scanning TOTP configuration to struct: %w", err) + } + + if _, err = p.decrypt(config.Secret); err != nil { + invalid++ + } + } + + _ = rows.Close() + + if row < pageSize { + break + } + } + + if invalid != 0 { + return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total) + } + + return nil +} + +func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) { + var ( + device models.U2FDevice + row int + invalid int + total int + ) + + pageSize := 10 + + var rows *sqlx.Rows + + for page := 0; true; page++ { + if rows, err = p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, pageSize, pageSize*page); err != nil { + _ = rows.Close() + + return fmt.Errorf("error selecting U2F devices: %w", err) + } + + row = 0 + + for rows.Next() { + total++ + row++ + + if err = rows.StructScan(&device); err != nil { + _ = rows.Close() + return fmt.Errorf("error scanning U2F device to struct: %w", err) + } + + if _, err = p.decrypt(device.PublicKey); err != nil { + invalid++ + } + } + + _ = rows.Close() + + if row < pageSize { + break + } + } + + if invalid != 0 { + return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total) + } + + return nil +} + func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { return utils.Encrypt(clearText, &p.key) } diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index 9b6e66d4d..5ffbb1cf8 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -60,16 +60,16 @@ const ( SELECT EXISTS ( SELECT id FROM %s - WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND used IS NULL + WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND consumed IS NULL );` queryFmtInsertIdentityVerification = ` - INSERT INTO %s (jti, iat, exp, username, action) - VALUES (?, ?, ?, ?, ?);` + INSERT INTO %s (jti, iat, issued_ip, exp, username, action) + VALUES (?, ?, ?, ?, ?, ?);` - queryFmtDeleteIdentityVerification = ` + queryFmtConsumeIdentityVerification = ` UPDATE %s - SET used = CURRENT_TIMESTAMP + SET consumed = CURRENT_TIMESTAMP, consumed_ip = ? WHERE jti = ?;` ) @@ -114,10 +114,26 @@ const ( const ( queryFmtSelectU2FDevice = ` - SELECT key_handle, public_key + SELECT id, username, key_handle, public_key FROM %s WHERE username = ?;` + queryFmtSelectU2FDevices = ` + SELECT id, username, key_handle, public_key + FROM %s + LIMIT ? + OFFSET ?;` + + queryFmtUpdateU2FDevicePublicKey = ` + UPDATE %s + SET public_key = ? + WHERE id = ?;` + + queryFmtUpdateUpdateU2FDevicePublicKeyByUsername = ` + UPDATE %s + SET public_key = ? + WHERE username = ?;` + queryFmtUpsertU2FDevice = ` REPLACE INTO %s (username, description, key_handle, public_key) VALUES (?, ?, ?, ?);` diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index dbe2bc1a0..9597e07bb 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -2,6 +2,7 @@ package storage import ( "context" + "errors" "fmt" "strconv" "time" @@ -57,9 +58,13 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error return migration.After, nil } - if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) && - utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) && - utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) { + var tablesV1 = []string{tableDuoDevices, tableEncryption, tableIdentityVerification, tableMigrations, tableTOTPConfigurations} + + if utils.IsStringSliceContainsAll(tablesPre1, tables) { + if utils.IsStringSliceContainsAny(tablesV1, tables) { + return -2, errors.New("pre1 schema contains v1 tables it shouldn't contain") + } + return -1, nil } diff --git a/internal/storage/sql_provider_schema_pre1.go b/internal/storage/sql_provider_schema_pre1.go index e577d1f77..5bf3ad08d 100644 --- a/internal/storage/sql_provider_schema_pre1.go +++ b/internal/storage/sql_provider_schema_pre1.go @@ -267,7 +267,12 @@ func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) { return err } - devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey}) + encryptedPublicKey, err := p.encrypt(publicKey) + if err != nil { + return err + } + + devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: encryptedPublicKey}) } for _, device := range devices { @@ -446,6 +451,11 @@ func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) { return err } + device.PublicKey, err = p.decrypt(device.PublicKey) + if err != nil { + return err + } + devices = append(devices, device) } diff --git a/internal/utils/strings.go b/internal/utils/strings.go index eec7ba03b..19e049200 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -91,6 +91,17 @@ func IsStringSliceContainsAll(needles []string, haystack []string) (inSlice bool return true } +// IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles. +func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) { + for _, n := range needles { + if IsStringInSlice(n, haystack) { + return true + } + } + + return false +} + // SliceString splits a string s into an array with each item being a max of int d // d = denominator, n = numerator, q = quotient, r = remainder. func SliceString(s string, d int) (array []string) { diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index cc79c50d3..c006f57e3 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -162,3 +162,12 @@ func TestIsStringSliceContainsAll(t *testing.T) { assert.True(t, IsStringSliceContainsAll(needles, haystackOne)) assert.False(t, IsStringSliceContainsAll(needles, haystackTwo)) } + +func TestIsStringSliceContainsAny(t *testing.T) { + needles := []string{"abc", "123", "xyz"} + haystackOne := []string{"tvu", "456", "hij"} + haystackTwo := []string{"tvu", "123", "456", "xyz"} + + assert.False(t, IsStringSliceContainsAny(needles, haystackOne)) + assert.True(t, IsStringSliceContainsAny(needles, haystackTwo)) +}