feat(storage): encrypt u2f key (#2664)

Adds encryption to the U2F public keys. While the public keys cannot be used to authenticate, only to validate someone is authenticated, if a rogue operator changed these in the database they may be able to bypass 2FA. This prevents that.
pull/2666/head
James Elliott 2021-12-03 11:04:11 +11:00 committed by GitHub
parent 104a61ecd6
commit 255aaeb2ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 528 additions and 182 deletions

View File

@ -31,12 +31,12 @@ required: yes
{: .label .label-config .label-red }
</div>
The encryption key used to encrypt data in the database. It has a minimum length of 20 and must be provided. We encrypt
data by creating a sha256 checksum of the provided value, and use that to encrypt the data with the AES-GCM 256bit
algorithm.
The encryption key used to encrypt data in the database. We encrypt data by creating a sha256 checksum of the provided
value, and use that to encrypt the data with the AES-GCM 256bit algorithm.
The encrypted data in the database is as follows:
- TOTP Secret
The minimum length of this key is 20 characters, however we generally recommend above 64 characters.
See [securty measures](../../security/measures.md#storage-security-measures) for more information.
### local
See [SQLite](./sqlite.md).

View File

@ -81,6 +81,54 @@ LDAP implementations vary, so please ask if you need some assistance in configur
These protections can be [tuned](../configuration/authentication/ldap.md#refresh-interval) according to your security
policy by changing refresh_interval, however we believe that 5 minutes is a fairly safe interval.
## Storage security measures
We force users to encrypt vulnerable data stored in the database. It is strongly advised you do not give this encryption
key to anyone. In the instance of a database installation that multiple users have access to, you should aim to ensure
that users who have access to the database do not also have access to this key.
The encrypted data in the database is as follows:
|Table |Column |Rational |
|:-----------------:|:--------:|:----------------------------------------------------------------------------------------------------:|
|totp_configurations|secret |Prevents a [Leaked Database](#leaked-database) or [Bad Actors](#bad-actors) from compromising security|
|u2f_devices |public_key|Prevents [Bad Actors](#bad-actors) from compromising security |
### Leaked Database
A leaked database can reasonably compromise security if there are credentials that are not encrypted. Columns encrypted
for this purpose prevent this attack vector.
### Bad Actors
A bad actor who has the SQL password and access to the database can theoretically change another users credential, this
theoretically bypasses authentication. Columns encrypted for this purpose prevent this attack vector.
A bad actor may also be able to use data in the database to bypass 2FA silently depending on the credentials. In the
instance of the U2F public key this is not possible, they can only change it which would eventually alert the user in
question. But in the case of TOTP they can use the secret to authenticate without knowledge of the user in question.
### Encryption key management
You must supply the encryption key in the recommended method of a [secret](../configuration/secrets.md) or in one of
the other methods available for [configuration](../configuration/index.md#configuration).
If you wish to change your encryption key for any reason you can do so using the following steps:
1. Run the `authelia --version` command to determine the version of Authelia you're running and either download that
version or run another container of that version interactively. All the subsequent commands assume you're running
the `authelia` binary in the current working directory. You will have to adjust this according to how you're running
it.
2. Run the `./authelia storage encryption change-key --help` command.
3. Stop Authelia.
- You can skip this step, however note that any data changed between the time you make the change and the time when
you stop Authelia i.e. via user registering a device; will be encrypted with the incorrect key.
4. Run the `./authelia storage encryption change-key` command with the appropriate parameters.
- The help from step 1 will be useful here. The easiest method to accomplish this is with the `--config`,
`--encryption-key`, and `--new-encryption-key` parameters.
5. Update the encryption key Authelia uses on startup.
6. Start Authelia.
## Notifier security measures (SMTP)
The SMTP Notifier implementation does not allow connections that are not secure without changing default configuration

View File

@ -52,7 +52,6 @@ If properly configured, Authelia guarantees the following for security of your u
* Binding session cookies to single IP addresses.
* Authenticate communication between Authelia and reverse proxy.
* Securely transmit authentication data to backends (OAuth2 with bearer tokens).
* Protect secrets stored in the database with encryption to prevent secrets leak by database exfiltration.
* Least privilege on LDAP binding operations (currently administrative user is used to bind while it could be anonymous
for most operations).
* Extend the check of user group memberships to authentication backends other than LDAP (File currently).

View File

@ -65,7 +65,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthType1FA,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.Ctx.Request.SetBodyString(`{
@ -93,7 +93,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthType1FA,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.Ctx.Request.SetBodyString(`{
@ -119,7 +119,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthType1FA,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.Ctx.Request.SetBodyString(`{

View File

@ -34,21 +34,21 @@ func (s *HandlerRegisterU2FStep1Suite) TearDownTest() {
s.mock.Close()
}
func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
verification = models.NewIdentityVerification(username, action)
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP())
verification.ExpiresAt = expiresAt
claims := verification.ToIdentityVerificationClaim()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, _ := token.SignedString([]byte(secret))
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
return ss, verification
}
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() {
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration,
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -57,7 +57,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
Return(true, nil)
s.mock.StorageMock.EXPECT().
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
Return(nil)
SecondFactorU2FIdentityFinish(s.mock.Ctx)
@ -68,7 +68,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration,
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -77,7 +77,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin
Return(true, nil)
s.mock.StorageMock.EXPECT().
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
Return(nil)
SecondFactorU2FIdentityFinish(s.mock.Ctx)

View File

@ -97,7 +97,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -286,7 +286,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -414,7 +414,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndDenyAccess() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -497,7 +497,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToDefaultURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -546,7 +546,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotReturnRedirectURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -591,7 +591,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToSafeTargetURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -640,7 +640,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotRedirectToUnsafeURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)
@ -687,7 +687,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRegenerateSessionForPreventingSessi
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeDuo,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
})).
Return(nil)

View File

@ -51,7 +51,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeTOTP,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
@ -85,7 +85,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeTOTP,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
@ -115,7 +115,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeTOTP,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
@ -146,7 +146,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeTOTP,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.TOTPMock.EXPECT().
@ -180,7 +180,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFi
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeTOTP,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.TOTPMock.EXPECT().

View File

@ -51,7 +51,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToDefaultURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeU2F,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL
@ -83,7 +83,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotReturnRedirectURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeU2F,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
bodyBytes, err := json.Marshal(signU2FRequestBody{
@ -111,7 +111,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToSafeTargetURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeU2F,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
bodyBytes, err := json.Marshal(signU2FRequestBody{
@ -142,7 +142,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotRedirectToUnsafeURL() {
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeU2F,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
bodyBytes, err := json.Marshal(signU2FRequestBody{
@ -171,7 +171,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRegenerateSessionForPreventingSessi
Banned: false,
Time: s.mock.Clock.Now(),
Type: regulation.AuthTypeU2F,
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
}))
bodyBytes, err := json.Marshal(signU2FRequestBody{

View File

@ -27,7 +27,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
return
}
verification := models.NewIdentityVerification(identity.Username, args.ActionClaim)
verification := models.NewIdentityVerification(identity.Username, args.ActionClaim, ctx.RemoteIP())
// Create the claim with the action to sign it.
claims := verification.ToIdentityVerificationClaim()
@ -183,7 +183,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
return
}
err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, claims.ID)
err = ctx.Providers.StorageProvider.ConsumeIdentityVerification(ctx, claims.ID, models.NewNullIP(ctx.RemoteIP()))
if err != nil {
ctx.Error(err, messageOperationFailed)
return

View File

@ -165,15 +165,15 @@ func (s *IdentityVerificationFinishProcess) TearDownTest() {
s.mock.Close()
}
func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
verification = models.NewIdentityVerification(username, action)
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP())
verification.ExpiresAt = expiresAt
claims := verification.ToIdentityVerificationClaim()
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, _ := token.SignedString([]byte(secret))
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
return ss, verification
}
@ -203,7 +203,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotProvided()
}
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() {
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "Login",
token, verification := createToken(s.mock, "john", "Login",
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -229,7 +229,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() {
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
args := newArgs(defaultRetriever)
token, _ := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", args.ActionClaim,
token, _ := createToken(s.mock, "john", args.ActionClaim,
time.Now().Add(-1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -240,7 +240,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
}
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "", "",
token, verification := createToken(s.mock, "", "",
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -255,7 +255,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
}
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "harry", "EXP_ACTION",
token, verification := createToken(s.mock, "harry", "EXP_ACTION",
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -272,7 +272,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
}
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() {
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION",
token, verification := createToken(s.mock, "john", "EXP_ACTION",
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -281,7 +281,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
Return(true, nil)
s.mock.StorageMock.EXPECT().
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
Return(fmt.Errorf("cannot remove"))
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
@ -291,7 +291,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
}
func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() {
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION",
token, verification := createToken(s.mock, "john", "EXP_ACTION",
time.Now().Add(1*time.Minute))
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
@ -300,7 +300,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete(
Return(true, nil)
s.mock.StorageMock.EXPECT().
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
Return(nil)
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)

View File

@ -65,6 +65,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close))
}
// ConsumeIdentityVerification mocks base method.
func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 models.NullIP) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConsumeIdentityVerification", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ConsumeIdentityVerification indicates an expected call of ConsumeIdentityVerification.
func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
}
// DeletePreferredDuoDevice mocks base method.
func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
@ -198,6 +212,21 @@ func (mr *MockStorageMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevice), arg0, arg1)
}
// LoadU2FDevices mocks base method.
func (m *MockStorage) LoadU2FDevices(arg0 context.Context, arg1, arg2 int) ([]models.U2FDevice, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadU2FDevices", arg0, arg1, arg2)
ret0, _ := ret[0].([]models.U2FDevice)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadU2FDevices indicates an expected call of LoadU2FDevices.
func (mr *MockStorageMockRecorder) LoadU2FDevices(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevices", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevices), arg0, arg1, arg2)
}
// LoadUserInfo mocks base method.
func (m *MockStorage) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) {
m.ctrl.T.Helper()
@ -213,20 +242,6 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1)
}
// RemoveIdentityVerification mocks base method.
func (m *MockStorage) RemoveIdentityVerification(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification.
func (mr *MockStorageMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).RemoveIdentityVerification), arg0, arg1)
}
// SaveIdentityVerification mocks base method.
func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error {
m.ctrl.T.Helper()
@ -442,17 +457,3 @@ func (mr *MockStorageMockRecorder) StartupCheck() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockStorage)(nil).StartupCheck))
}
// UpdateTOTPConfigurationSecret mocks base method.
func (m *MockStorage) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 models.TOTPConfiguration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateTOTPConfigurationSecret", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateTOTPConfigurationSecret indicates an expected call of UpdateTOTPConfigurationSecret.
func (mr *MockStorageMockRecorder) UpdateTOTPConfigurationSecret(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTOTPConfigurationSecret", reflect.TypeOf((*MockStorage)(nil).UpdateTOTPConfigurationSecret), arg0, arg1)
}

View File

@ -12,7 +12,7 @@ type AuthenticationAttempt struct {
Banned bool `db:"banned"`
Username string `db:"username"`
Type string `db:"auth_type"`
RemoteIP IPAddress `db:"remote_ip"`
RemoteIP NullIP `db:"remote_ip"`
RequestURI string `db:"request_uri"`
RequestMethod string `db:"request_method"`
}

View File

@ -1,6 +1,7 @@
package models
import (
"net"
"time"
"github.com/golang-jwt/jwt/v4"
@ -8,13 +9,14 @@ import (
)
// NewIdentityVerification creates a new IdentityVerification from a given username and action.
func NewIdentityVerification(username, action string) (verification IdentityVerification) {
func NewIdentityVerification(username, action string, ip net.IP) (verification IdentityVerification) {
return IdentityVerification{
JTI: uuid.New(),
IssuedAt: time.Now(),
ExpiresAt: time.Now().Add(5 * time.Minute),
Action: action,
Username: username,
IssuedIP: NewIP(ip),
}
}
@ -23,10 +25,12 @@ type IdentityVerification struct {
ID int `db:"id"`
JTI uuid.UUID `db:"jti"`
IssuedAt time.Time `db:"iat"`
IssuedIP IP `db:"issued_ip"`
ExpiresAt time.Time `db:"exp"`
Used *time.Time `db:"used"`
Action string `db:"action"`
Username string `db:"username"`
Consumed *time.Time `db:"consumed"`
ConsumedIP NullIP `db:"consumed_ip"`
}
// ToIdentityVerificationClaim converts the IdentityVerification into a IdentityVerificationClaim.

View File

@ -2,23 +2,71 @@ package models
import (
"database/sql/driver"
"errors"
"fmt"
"net"
)
// NewIPAddressFromString converts a string into an IPAddress.
func NewIPAddressFromString(ip string) (ipAddress IPAddress) {
actualIP := net.ParseIP(ip)
return IPAddress{IP: &actualIP}
// NewIP easily constructs a new IP.
func NewIP(value net.IP) (ip IP) {
return IP{IP: value}
}
// IPAddress is a type specific for storage of a net.IP in the database.
type IPAddress struct {
*net.IP
// NewNullIP easily constructs a new NullIP.
func NewNullIP(value net.IP) (ip NullIP) {
return NullIP{IP: value}
}
// Value is the IPAddress implementation of the databases/sql driver.Valuer.
func (ip IPAddress) Value() (value driver.Value, err error) {
// NewNullIPFromString easily constructs a new NullIP from a string.
func NewNullIPFromString(value string) (ip NullIP) {
if value == "" {
return ip
}
return NullIP{IP: net.ParseIP(value)}
}
// IP is a type specific for storage of a net.IP in the database which can't be NULL.
type IP struct {
IP net.IP
}
// Value is the IP implementation of the databases/sql driver.Valuer.
func (ip IP) Value() (value driver.Value, err error) {
if ip.IP == nil {
return nil, errors.New("cannot value nil IP to driver.Value")
}
return driver.Value(ip.IP.String()), nil
}
// Scan is the IP implementation of the sql.Scanner.
func (ip *IP) Scan(src interface{}) (err error) {
if src == nil {
return errors.New("cannot scan nil to type IP")
}
var value string
switch v := src.(type) {
case string:
value = v
default:
return fmt.Errorf("invalid type %T for IP %v", src, src)
}
ip.IP = net.ParseIP(value)
return nil
}
// NullIP is a type specific for storage of a net.IP in the database which can also be NULL.
type NullIP struct {
IP net.IP
}
// Value is the NullIP implementation of the databases/sql driver.Valuer.
func (ip NullIP) Value() (value driver.Value, err error) {
if ip.IP == nil {
return driver.Value(nil), nil
}
@ -26,8 +74,8 @@ func (ip IPAddress) Value() (value driver.Value, err error) {
return driver.Value(ip.IP.String()), nil
}
// Scan is the IPAddress implementation of the sql.Scanner.
func (ip *IPAddress) Scan(src interface{}) (err error) {
// Scan is the NullIP implementation of the sql.Scanner.
func (ip *NullIP) Scan(src interface{}) (err error) {
if src == nil {
ip.IP = nil
return nil
@ -39,10 +87,10 @@ func (ip *IPAddress) Scan(src interface{}) (err error) {
case string:
value = v
default:
return fmt.Errorf("invalid type %T for IPAddress %v", src, src)
return fmt.Errorf("invalid type %T for NullIP %v", src, src)
}
*ip.IP = net.ParseIP(value)
ip.IP = net.ParseIP(value)
return nil
}

View File

@ -51,7 +51,7 @@ func (r *Regulator) Mark(ctx context.Context, successful, banned bool, username,
Banned: banned,
Username: username,
Type: authType,
RemoteIP: models.IPAddress{IP: &remoteIP},
RemoteIP: models.NewNullIP(remoteIP),
RequestURI: requestURI,
RequestMethod: requestMethod,
})

View File

@ -24,8 +24,10 @@ const (
// WARNING: Do not change/remove these consts. They are used for Pre1 migrations.
const (
tablePre1TOTPSecrets = "totp_secrets"
tablePre1Config = "config"
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
tablePre1Config = "config"
tableAlphaAuthenticationLogs = "AuthenticationLogs"
tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
tableAlphaPreferences = "Preferences"
@ -35,6 +37,15 @@ const (
tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
)
var tablesPre1 = []string{
tablePre1TOTPSecrets,
tablePre1IdentityVerificationTokens,
tableUserPreferences,
tableU2FDevices,
tableAuthenticationLogs,
}
const (
providerAll = "all"
providerMySQL = "mysql"

View File

@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
banned BOOLEAN NOT NULL DEFAULT FALSE,
username VARCHAR(100) NOT NULL,
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
remote_ip VARCHAR(47) NULL DEFAULT NULL,
remote_ip VARCHAR(39) NULL DEFAULT NULL,
request_uri TEXT NOT NULL,
request_method VARCHAR(8) NOT NULL DEFAULT '',
PRIMARY KEY (id)
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
id INTEGER AUTO_INCREMENT,
jti CHAR(36),
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP NOT NULL,
used TIMESTAMP NULL DEFAULT NULL,
username VARCHAR(100) NOT NULL,
action VARCHAR(50) NOT NULL,
consumed TIMESTAMP NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
PRIMARY KEY (id),
UNIQUE KEY (jti)
);

View File

@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
banned BOOLEAN NOT NULL DEFAULT FALSE,
username VARCHAR(100) NOT NULL,
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
remote_ip VARCHAR(47) NULL DEFAULT NULL,
remote_ip VARCHAR(39) NULL DEFAULT NULL,
request_uri TEXT,
request_method VARCHAR(8) NOT NULL DEFAULT '',
PRIMARY KEY (id)
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
id SERIAL,
jti CHAR(36),
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP WITH TIME ZONE NOT NULL,
used TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
username VARCHAR(100) NOT NULL,
action VARCHAR(50) NOT NULL,
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
PRIMARY KEY (id),
UNIQUE (jti)
);

View File

@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
banned BOOLEAN NOT NULL DEFAULT FALSE,
username VARCHAR(100) NOT NULL,
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
remote_ip VARCHAR(47) NULL DEFAULT NULL,
remote_ip VARCHAR(39) NULL DEFAULT NULL,
request_uri TEXT,
request_method VARCHAR(8) NOT NULL DEFAULT '',
PRIMARY KEY (id)
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
id INTEGER,
jti VARCHAR(36),
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP NOT NULL,
used TIMESTAMP NULL DEFAULT NULL,
username VARCHAR(100) NOT NULL,
action VARCHAR(50) NOT NULL,
consumed TIMESTAMP NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
PRIMARY KEY (id),
UNIQUE (jti)
);

View File

@ -18,17 +18,17 @@ type Provider interface {
LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error)
SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error)
RemoveIdentityVerification(ctx context.Context, jti string) (err error)
ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error)
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error)
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error)
LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error)
UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error)
SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error)
LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error)
LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error)
SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error)
DeletePreferredDuoDevice(ctx context.Context, username string) (err error)

View File

@ -34,7 +34,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
@ -47,6 +47,10 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
sqlSelectU2FDevices: fmt.Sprintf(queryFmtSelectU2FDevices, tableU2FDevices),
sqlUpdateU2FDevicePublicKey: fmt.Sprintf(queryFmtUpdateU2FDevicePublicKey, tableU2FDevices),
sqlUpdateU2FDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateU2FDevicePublicKeyByUsername, tableU2FDevices),
sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices),
sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices),
@ -86,7 +90,7 @@ type SQLProvider struct {
// Table: identity_verification.
sqlInsertIdentityVerification string
sqlDeleteIdentityVerification string
sqlConsumeIdentityVerification string
sqlSelectExistsIdentityVerification string
// Table: totp_configurations.
@ -101,6 +105,10 @@ type SQLProvider struct {
// Table: u2f_devices.
sqlUpsertU2FDevice string
sqlSelectU2FDevice string
sqlSelectU2FDevices string
sqlUpdateU2FDevicePublicKey string
sqlUpdateU2FDevicePublicKeyByUsername string
// Table: duo_devices
sqlUpsertDuoDevice string
@ -217,7 +225,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m
// SaveIdentityVerification save an identity verification record to the database.
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
verification.JTI, verification.IssuedAt, verification.ExpiresAt,
verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt,
verification.Username, verification.Action); err != nil {
return fmt.Errorf("error inserting identity verification: %w", err)
}
@ -225,9 +233,9 @@ func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification
return nil
}
// RemoveIdentityVerification remove an identity verification record from the database.
func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, jti string) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, jti); err != nil {
// ConsumeIdentityVerification marks an identity verification record in the database as consumed.
func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil {
return fmt.Errorf("error updating identity verification: %w", err)
}
@ -321,8 +329,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
return configs, nil
}
// UpdateTOTPConfigurationSecret updates a TOTP configuration secret.
func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
switch config.ID {
case 0:
_, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
@ -339,6 +346,10 @@ func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config
// SaveU2FDevice saves a registered U2F device.
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
return fmt.Errorf("error encrypting the U2F device public key: %v", err)
}
if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.Description, device.KeyHandle, device.PublicKey); err != nil {
return fmt.Errorf("error upserting U2F device: %v", err)
}
@ -348,9 +359,7 @@ func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice
// LoadU2FDevice loads a U2F device registration for a given username.
func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
device = &models.U2FDevice{
Username: username,
}
device = &models.U2FDevice{}
if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil {
if errors.Is(err, sql.ErrNoRows) {
@ -360,9 +369,64 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic
return nil, fmt.Errorf("error selecting U2F device: %w", err)
}
if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil {
return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err)
}
return device, nil
}
// LoadU2FDevices loads U2F device registrations.
func (p *SQLProvider) LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error) {
rows, err := p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, limit, limit*page)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return devices, nil
}
return nil, fmt.Errorf("error selecting U2F devices: %w", err)
}
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
devices = make([]models.U2FDevice, 0, limit)
var device models.U2FDevice
for rows.Next() {
if err = rows.StructScan(&device); err != nil {
return nil, fmt.Errorf("error scanning U2F device to struct: %w", err)
}
if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil {
return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err)
}
devices = append(devices, device)
}
return devices, nil
}
func (p *SQLProvider) updateU2FDevicePublicKey(ctx context.Context, device models.U2FDevice) (err error) {
switch device.ID {
case 0:
_, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKeyByUsername, device.PublicKey, device.Username)
default:
_, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKey, device.PublicKey, device.ID)
}
if err != nil {
return fmt.Errorf("error updating U2F public key: %w", err)
}
return nil
}
// SavePreferredDuoDevice saves a Duo device.
func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method)

View File

@ -38,13 +38,16 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification)
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification)
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice)
provider.sqlSelectU2FDevices = provider.db.Rebind(provider.sqlSelectU2FDevices)
provider.sqlUpdateU2FDevicePublicKey = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKey)
provider.sqlUpdateU2FDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKeyByUsername)
provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice)
provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice)
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)

View File

@ -22,6 +22,26 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
key := sha256.Sum256([]byte(encryptionKey))
if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
return err
}
if err = p.schemaEncryptionChangeKeyU2F(ctx, tx, key); err != nil {
return err
}
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
}
return tx.Commit()
}
func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
var configs []models.TOTPConfiguration
for page := 0; true; page++ {
@ -42,7 +62,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
return fmt.Errorf("rollback due to error: %w", err)
}
if err = p.UpdateTOTPConfigurationSecret(ctx, config); err != nil {
if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
}
@ -56,7 +76,14 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
}
}
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
return nil
}
func (p *SQLProvider) schemaEncryptionChangeKeyU2F(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
var devices []models.U2FDevice
for page := 0; true; page++ {
if devices, err = p.LoadU2FDevices(ctx, 10, page); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
}
@ -64,7 +91,30 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
return fmt.Errorf("rollback due to error: %w", err)
}
return tx.Commit()
for _, device := range devices {
if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
}
if err = p.updateU2FDevicePublicKey(ctx, device); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
}
}
if len(devices) != 10 {
break
}
}
return nil
}
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
@ -85,6 +135,33 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
}
if verbose {
if err = p.schemaEncryptionCheckTOTP(ctx); err != nil {
errs = append(errs, err)
}
if err = p.schemaEncryptionCheckU2F(ctx); err != nil {
errs = append(errs, err)
}
}
if len(errs) != 0 {
for i, e := range errs {
if i == 0 {
err = e
continue
}
err = fmt.Errorf("%w, %v", err, e)
}
return err
}
return nil
}
func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
var (
config models.TOTPConfiguration
row int
@ -127,22 +204,56 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
}
if invalid != 0 {
errs = append(errs, fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total))
return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
}
return nil
}
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
var (
device models.U2FDevice
row int
invalid int
total int
)
pageSize := 10
var rows *sqlx.Rows
for page := 0; true; page++ {
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, pageSize, pageSize*page); err != nil {
_ = rows.Close()
return fmt.Errorf("error selecting U2F devices: %w", err)
}
row = 0
for rows.Next() {
total++
row++
if err = rows.StructScan(&device); err != nil {
_ = rows.Close()
return fmt.Errorf("error scanning U2F device to struct: %w", err)
}
if _, err = p.decrypt(device.PublicKey); err != nil {
invalid++
}
}
if len(errs) != 0 {
for i, e := range errs {
if i == 0 {
err = e
_ = rows.Close()
continue
if row < pageSize {
break
}
}
err = fmt.Errorf("%w, %v", err, e)
}
return err
if invalid != 0 {
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
}
return nil

View File

@ -60,16 +60,16 @@ const (
SELECT EXISTS (
SELECT id
FROM %s
WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND used IS NULL
WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND consumed IS NULL
);`
queryFmtInsertIdentityVerification = `
INSERT INTO %s (jti, iat, exp, username, action)
VALUES (?, ?, ?, ?, ?);`
INSERT INTO %s (jti, iat, issued_ip, exp, username, action)
VALUES (?, ?, ?, ?, ?, ?);`
queryFmtDeleteIdentityVerification = `
queryFmtConsumeIdentityVerification = `
UPDATE %s
SET used = CURRENT_TIMESTAMP
SET consumed = CURRENT_TIMESTAMP, consumed_ip = ?
WHERE jti = ?;`
)
@ -114,10 +114,26 @@ const (
const (
queryFmtSelectU2FDevice = `
SELECT key_handle, public_key
SELECT id, username, key_handle, public_key
FROM %s
WHERE username = ?;`
queryFmtSelectU2FDevices = `
SELECT id, username, key_handle, public_key
FROM %s
LIMIT ?
OFFSET ?;`
queryFmtUpdateU2FDevicePublicKey = `
UPDATE %s
SET public_key = ?
WHERE id = ?;`
queryFmtUpdateUpdateU2FDevicePublicKeyByUsername = `
UPDATE %s
SET public_key = ?
WHERE username = ?;`
queryFmtUpsertU2FDevice = `
REPLACE INTO %s (username, description, key_handle, public_key)
VALUES (?, ?, ?, ?);`

View File

@ -2,6 +2,7 @@ package storage
import (
"context"
"errors"
"fmt"
"strconv"
"time"
@ -57,9 +58,13 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error
return migration.After, nil
}
if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) &&
utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) &&
utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) {
var tablesV1 = []string{tableDuoDevices, tableEncryption, tableIdentityVerification, tableMigrations, tableTOTPConfigurations}
if utils.IsStringSliceContainsAll(tablesPre1, tables) {
if utils.IsStringSliceContainsAny(tablesV1, tables) {
return -2, errors.New("pre1 schema contains v1 tables it shouldn't contain")
}
return -1, nil
}

View File

@ -267,7 +267,12 @@ func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) {
return err
}
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey})
encryptedPublicKey, err := p.encrypt(publicKey)
if err != nil {
return err
}
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: encryptedPublicKey})
}
for _, device := range devices {
@ -446,6 +451,11 @@ func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) {
return err
}
device.PublicKey, err = p.decrypt(device.PublicKey)
if err != nil {
return err
}
devices = append(devices, device)
}

View File

@ -91,6 +91,17 @@ func IsStringSliceContainsAll(needles []string, haystack []string) (inSlice bool
return true
}
// IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles.
func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) {
for _, n := range needles {
if IsStringInSlice(n, haystack) {
return true
}
}
return false
}
// SliceString splits a string s into an array with each item being a max of int d
// d = denominator, n = numerator, q = quotient, r = remainder.
func SliceString(s string, d int) (array []string) {

View File

@ -162,3 +162,12 @@ func TestIsStringSliceContainsAll(t *testing.T) {
assert.True(t, IsStringSliceContainsAll(needles, haystackOne))
assert.False(t, IsStringSliceContainsAll(needles, haystackTwo))
}
func TestIsStringSliceContainsAny(t *testing.T) {
needles := []string{"abc", "123", "xyz"}
haystackOne := []string{"tvu", "456", "hij"}
haystackTwo := []string{"tvu", "123", "456", "xyz"}
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
}