diff --git a/docs/content/en/configuration/storage/migrations.md b/docs/content/en/configuration/storage/migrations.md index 0cec5b2a2..b61cb7835 100644 --- a/docs/content/en/configuration/storage/migrations.md +++ b/docs/content/en/configuration/storage/migrations.md @@ -39,3 +39,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel | 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests | | 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table | | 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes | +| 11 | 4.38.0 | One-Time Password for Identity Verification via Email Changes | diff --git a/internal/configuration/schema/configuration.go b/internal/configuration/schema/configuration.go index 74f3e4faf..1dfac10f0 100644 --- a/internal/configuration/schema/configuration.go +++ b/internal/configuration/schema/configuration.go @@ -24,4 +24,5 @@ type Configuration struct { WebAuthn WebAuthnConfiguration `koanf:"webauthn"` PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"` PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"` + IdentityValidation IdentityValidation `koanf:"identity_validation"` } diff --git a/internal/configuration/schema/identity_validation.go b/internal/configuration/schema/identity_validation.go new file mode 100644 index 000000000..10114e549 --- /dev/null +++ b/internal/configuration/schema/identity_validation.go @@ -0,0 +1,21 @@ +package schema + +import ( + "time" +) + +type IdentityValidation struct { + ResetPassword ResetPasswordIdentityValidation `koanf:"reset_password"` + CredentialRegistration CredentialRegistrationIdentityValidation `koanf:"credential_registration"` +} + +type ResetPasswordIdentityValidation struct { + EmailExpiration time.Duration `koanf:"email_expiration"` +} + +type CredentialRegistrationIdentityValidation struct { + EmailExpiration time.Duration `koanf:"email_expiration"` + ElevationExpiration time.Duration `koanf:"elevation_expiration"` + OTPCharacters int `koanf:"otp_characters"` + Skip2FA bool `koanf:"skip_2fa"` +} diff --git a/internal/handlers/handler_firstfactor_test.go b/internal/handlers/handler_firstfactor_test.go index 7c59857e7..c58fa7c71 100644 --- a/internal/handlers/handler_firstfactor_test.go +++ b/internal/handlers/handler_firstfactor_test.go @@ -446,7 +446,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi s.mock.Ctx.Request.SetBodyString(`{ "username": "test", "password": "hello", - "requestMethod": "GET", + "requestMethod": fasthttp.MethodGet, "keepMeLoggedIn": false }`) diff --git a/internal/middlewares/protected.go b/internal/middlewares/protected.go new file mode 100644 index 000000000..1d2819e62 --- /dev/null +++ b/internal/middlewares/protected.go @@ -0,0 +1,235 @@ +package middlewares + +import ( + "fmt" + "time" + + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/session" +) + +// OTPEscalationProtectedEndpointConfig represents how the Escalation middleware behaves. +type OTPEscalationProtectedEndpointConfig struct { + Characters int + EmailValidityDuration time.Duration + EscalationValidityDuration time.Duration + Skip2FA bool +} + +type RequiredLevelProtectedEndpointConfig struct { + Level authentication.Level +} + +type ProtectedEndpointConfig struct { + OTPEscalation *OTPEscalationProtectedEndpointConfig + RequiredLevel *RequiredLevelProtectedEndpointConfig +} + +func NewProtectedEndpoint(config *ProtectedEndpointConfig) AutheliaMiddleware { + var handlers []ProtectedEndpointHandler + + if config.RequiredLevel != nil { + handlers = append(handlers, &RequiredLevelProtectedEndpointHandler{level: config.RequiredLevel.Level}) + } + + if config.OTPEscalation != nil { + handlers = append(handlers, &OTPEscalationProtectedEndpointHandler{config: config.OTPEscalation}) + } + + return ProtectedEndpoint(handlers...) +} + +func ProtectedEndpoint(handlers ...ProtectedEndpointHandler) AutheliaMiddleware { + n := len(handlers) + + return func(next RequestHandler) RequestHandler { + return func(ctx *AutheliaCtx) { + session, err := ctx.GetSession() + + if err != nil || session.IsAnonymous() { + ctx.SetAuthenticationErrorJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false) + + return + } + + var failed, failedAuthentication, failedElevation bool + + for i := 0; i < n; i++ { + if handlers[i].Check(ctx, &session) { + continue + } + + failed = true + + if handlers[i].IsAuthentication() { + failedAuthentication = true + } + + if handlers[i].IsElevation() { + failedElevation = true + } + + handlers[i].Failure(ctx, &session) + } + + if failed { + ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, fasthttp.StatusMessage(fasthttp.StatusForbidden), failedAuthentication, failedElevation) + + return + } + + next(ctx) + } + } +} + +type ProtectedEndpointHandler interface { + Name() string + Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) + Failure(ctx *AutheliaCtx, s *session.UserSession) + + IsAuthentication() bool + IsElevation() bool +} + +func NewRequiredLevelProtectedEndpointHandler(level authentication.Level, statusCode int) *RequiredLevelProtectedEndpointHandler { + handler := &RequiredLevelProtectedEndpointHandler{ + level: level, + statusCode: statusCode, + } + + if handler.statusCode == 0 { + handler.statusCode = fasthttp.StatusForbidden + } + + if handler.level == 0 { + handler.level = authentication.OneFactor + } + + return handler +} + +type RequiredLevelProtectedEndpointHandler struct { + level authentication.Level + statusCode int +} + +func (h *RequiredLevelProtectedEndpointHandler) Name() string { + return fmt.Sprintf("required_level(%s)", h.level) +} + +func (h *RequiredLevelProtectedEndpointHandler) IsAuthentication() bool { + return true +} + +func (h *RequiredLevelProtectedEndpointHandler) IsElevation() bool { + return false +} + +func (h *RequiredLevelProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) { + return s.AuthenticationLevel >= h.level +} + +func (h *RequiredLevelProtectedEndpointHandler) Failure(_ *AutheliaCtx, _ *session.UserSession) { +} + +func NewOTPEscalationProtectedEndpointHandler(config OTPEscalationProtectedEndpointConfig) *OTPEscalationProtectedEndpointHandler { + return &OTPEscalationProtectedEndpointHandler{ + config: &config, + } +} + +type OTPEscalationProtectedEndpointHandler struct { + config *OTPEscalationProtectedEndpointConfig +} + +func (h *OTPEscalationProtectedEndpointHandler) Name() string { + return "one_time_password" +} + +func (h *OTPEscalationProtectedEndpointHandler) IsAuthentication() bool { + return false +} + +func (h *OTPEscalationProtectedEndpointHandler) IsElevation() bool { + return true +} + +func (h *OTPEscalationProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) { + if h.config.Skip2FA && s.AuthenticationLevel >= authentication.TwoFactor { + ctx.Logger. + WithField("username", s.Username). + Warning("User elevated session check has skipped due to 2FA") + + return true + } + + if s.Elevations.User == nil { + ctx.Logger. + WithField("username", s.Username). + Warning("User elevated session has not been created") + + return false + } + + if s.Elevations.User.Expires.Before(ctx.Clock.Now()) { + ctx.Logger. + WithField("username", s.Username). + WithField("expires", s.Elevations.User.Expires). + Debug("User elevated session IP did not match the request") + + return false + } + + if !ctx.RemoteIP().Equal(s.Elevations.User.RemoteIP) { + ctx.Logger. + WithField("username", s.Username). + WithField("elevation_ip", s.Elevations.User.RemoteIP). + Warning("User elevated session IP did not match the request") + + return false + } + + return true +} + +func (h *OTPEscalationProtectedEndpointHandler) Failure(ctx *AutheliaCtx, s *session.UserSession) { + if s.Elevations.User != nil { + // If we make it here we should destroy the elevation data. + s.Elevations.User = nil + + if err := ctx.SaveSession(*s); err != nil { + ctx.Logger.WithError(err).Error("Error session after user elevated session failure") + } + } +} + +// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password). +func Require1FA(next RequestHandler) RequestHandler { + handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.OneFactor, fasthttp.StatusForbidden)) + + return handler(next) +} + +// Require2FA requires the user to have authenticated with two-factor authentication. +func Require2FA(next RequestHandler) RequestHandler { + handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.TwoFactor, fasthttp.StatusForbidden)) + + return handler(next) +} + +// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication. +func Require2FAWithAPIResponse(next RequestHandler) RequestHandler { + return func(ctx *AutheliaCtx) { + session, err := ctx.GetSession() + + if err != nil || session.AuthenticationLevel < authentication.TwoFactor { + ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false) + return + } + + next(ctx) + } +} diff --git a/internal/middlewares/protected_test.go b/internal/middlewares/protected_test.go new file mode 100644 index 000000000..fe6c4db92 --- /dev/null +++ b/internal/middlewares/protected_test.go @@ -0,0 +1,315 @@ +package middlewares_test + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/mocks" + "github.com/authelia/authelia/v4/internal/session" +) + +func handleSetStatus(code int) middlewares.RequestHandler { + return func(ctx *middlewares.AutheliaCtx) { + if err := ctx.ReplyJSON(middlewares.ErrorResponse{Status: "OK", Message: "Endpoint Response"}, 0); err != nil { + ctx.Logger.Error(err) + } + + ctx.SetStatusCode(code) + ctx.Response.Header.Set("X-Testing-Success", "yes") + } +} + +func TestProtectedEndpointRequiredLevel(t *testing.T) { + testCases := []struct { + name string + level authentication.Level + have session.UserSession + expected, status int + expectedAuth, expectedElevate bool + }{ + { + name: "1FAWithAuthenticatedUser2FAShould200OK", + level: authentication.OneFactor, + expected: fasthttp.StatusOK, + status: fasthttp.StatusOK, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.TwoFactor, + }, + }, + { + name: "1FAWithAuthenticatedUser1FAShould200OK", + level: authentication.OneFactor, + expected: fasthttp.StatusOK, + status: fasthttp.StatusOK, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.OneFactor, + }, + }, + { + name: "1FAWithAuthenticatedUser2FAShould301Found", + level: authentication.OneFactor, + expected: fasthttp.StatusFound, + status: fasthttp.StatusFound, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.TwoFactor, + }, + }, + { + name: "1FAWithAuthenticatedUser1FAShould301Found", + level: authentication.OneFactor, + expected: fasthttp.StatusFound, + status: fasthttp.StatusFound, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.OneFactor, + }, + }, + { + name: "1FAWithNotAuthenticatedUserShould401Unauthenticated", + level: authentication.OneFactor, + expected: fasthttp.StatusUnauthorized, + status: fasthttp.StatusOK, + have: session.UserSession{ + AuthenticationLevel: authentication.NotAuthenticated, + }, + }, + { + name: "2FAWithNotAuthenticatedUserShould401Unauthenticated", + level: authentication.OneFactor, + expected: fasthttp.StatusUnauthorized, + status: fasthttp.StatusFound, + have: session.UserSession{ + AuthenticationLevel: authentication.NotAuthenticated, + }, + }, + { + name: "2FAWithNotAuthenticatedUserShould401Unauthenticated", + level: authentication.TwoFactor, + expected: fasthttp.StatusForbidden, + expectedAuth: true, + status: fasthttp.StatusFound, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.OneFactor, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + err := mock.Ctx.SaveSession(tc.have) + + require.NoError(t, err) + + var h middlewares.RequestHandler + + switch tc.level { + case authentication.OneFactor: + h = middlewares.Require1FA(handleSetStatus(tc.status)) + default: + h = middlewares.Require2FA(handleSetStatus(tc.status)) + } + + h(mock.Ctx) + + assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode()) + + if tc.expected == tc.status { + assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body())) + assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success")) + } else { + assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":%t,"elevation":%t}`, fasthttp.StatusMessage(tc.expected), tc.expectedAuth, tc.expectedElevate), string(mock.Ctx.Response.Body())) + assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success")) + } + }) + } +} + +func TestProtectedEndpointOTP(t *testing.T) { + testCases := []struct { + name string + characters int + emailexp, sessionexp time.Duration + skip2fa bool + have session.UserSession + ip net.IP + time time.Time + expected, status int + }{ + { + name: "ReturnUnauthorizedForAnonymous", + characters: 10, + emailexp: time.Minute, + sessionexp: time.Minute, + skip2fa: false, + expected: fasthttp.StatusUnauthorized, + status: fasthttp.StatusFound, + have: session.UserSession{ + AuthenticationLevel: authentication.NotAuthenticated, + }, + }, + { + name: "Return200OKWhen2FASkipAndUserIs2FAd", + characters: 10, + emailexp: time.Minute, + sessionexp: time.Minute, + skip2fa: true, + expected: fasthttp.StatusOK, + status: fasthttp.StatusOK, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.TwoFactor, + }, + }, + { + name: "HandleEscalationEmailWhen2FASkipAndUserIs1FAd", + characters: 10, + emailexp: time.Minute, + sessionexp: time.Minute, + skip2fa: true, + expected: fasthttp.StatusForbidden, + status: fasthttp.StatusOK, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.OneFactor, + Elevations: session.Elevations{User: nil}, + }, + }, + { + name: "HandleEscalationEmailWhenUserIs2FAd", + characters: 10, + emailexp: time.Minute, + sessionexp: time.Minute, + skip2fa: false, + expected: fasthttp.StatusForbidden, + status: fasthttp.StatusOK, + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.TwoFactor, + Elevations: session.Elevations{User: nil}, + }, + }, + { + name: "Return200OKWhenUserIsEscalated", + characters: 10, + emailexp: time.Minute, + sessionexp: time.Minute, + skip2fa: false, + expected: fasthttp.StatusOK, + status: fasthttp.StatusOK, + ip: net.ParseIP("192.168.0.1"), + time: time.Unix(1671322337, 0), + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.TwoFactor, + Elevations: session.Elevations{ + User: &session.Elevation{ + RemoteIP: net.ParseIP("192.168.0.1"), + Expires: time.Unix(1671322347, 0), + }, + }, + }, + }, + { + name: "Return403ForbiddenWhenUserIsEscalatedButInvalidIP", + characters: 10, + emailexp: time.Minute, + sessionexp: time.Minute, + skip2fa: false, + expected: fasthttp.StatusForbidden, + status: fasthttp.StatusOK, + ip: net.ParseIP("192.168.0.2"), + time: time.Unix(1671322337, 0), + have: session.UserSession{ + Username: "john", + DisplayName: "John Wick", + Emails: []string{"john.wick@notmessingaround.com"}, + AuthenticationLevel: authentication.TwoFactor, + Elevations: session.Elevations{ + User: &session.Elevation{ + RemoteIP: net.ParseIP("192.168.0.1"), + Expires: time.Unix(1671322347, 0), + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, tc.ip.String()) + + err := mock.Ctx.SaveSession(tc.have) + + require.NoError(t, err) + + if !tc.time.IsZero() { + mock.Clock.Set(tc.time) + mock.Ctx.Clock = &mock.Clock + } + + h := middlewares.ProtectedEndpoint(middlewares.NewOTPEscalationProtectedEndpointHandler(middlewares.OTPEscalationProtectedEndpointConfig{ + Characters: tc.characters, + EmailValidityDuration: tc.emailexp, + EscalationValidityDuration: tc.sessionexp, + Skip2FA: tc.skip2fa, + }))(handleSetStatus(tc.status)) + + h(mock.Ctx) + + switch { + case tc.have.IsAnonymous(): + assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode()) + assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":false,"elevation":false}`, fasthttp.StatusMessage(tc.expected)), string(mock.Ctx.Response.Body())) + assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success")) + case tc.skip2fa && tc.have.AuthenticationLevel == authentication.TwoFactor: + assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode()) + assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body())) + assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success")) + case tc.have.Elevations.User == nil || mock.Ctx.Clock.Now().After(tc.have.Elevations.User.Expires) || !tc.ip.Equal(tc.have.Elevations.User.RemoteIP): + assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode()) + assert.Equal(t, `{"status":"KO","message":"Forbidden","authentication":false,"elevation":true}`, string(mock.Ctx.Response.Body())) + assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success")) + default: + assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode()) + assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body())) + assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success")) + } + }) + } +} diff --git a/internal/middlewares/require_authentication_level.go b/internal/middlewares/require_authentication_level.go deleted file mode 100644 index 2553a4417..000000000 --- a/internal/middlewares/require_authentication_level.go +++ /dev/null @@ -1,43 +0,0 @@ -package middlewares - -import ( - "github.com/valyala/fasthttp" - - "github.com/authelia/authelia/v4/internal/authentication" -) - -// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password). -func Require1FA(next RequestHandler) RequestHandler { - return func(ctx *AutheliaCtx) { - if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.OneFactor { - ctx.ReplyForbidden() - return - } - - next(ctx) - } -} - -// Require2FA requires the user to have authenticated with two-factor authentication. -func Require2FA(next RequestHandler) RequestHandler { - return func(ctx *AutheliaCtx) { - if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor { - ctx.ReplyForbidden() - return - } - - next(ctx) - } -} - -// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication. -func Require2FAWithAPIResponse(next RequestHandler) RequestHandler { - return func(ctx *AutheliaCtx) { - if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor { - ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false) - return - } - - next(ctx) - } -} diff --git a/internal/mocks/storage.go b/internal/mocks/storage.go index bc60c2145..9fc3d2085 100644 --- a/internal/mocks/storage.go +++ b/internal/mocks/storage.go @@ -110,6 +110,20 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2) } +// ConsumeOneTimePassword mocks base method. +func (m *MockStorage) ConsumeOneTimePassword(arg0 context.Context, arg1 *model.OneTimePassword) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConsumeOneTimePassword", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// ConsumeOneTimePassword indicates an expected call of ConsumeOneTimePassword. +func (mr *MockStorageMockRecorder) ConsumeOneTimePassword(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).ConsumeOneTimePassword), arg0, arg1) +} + // DeactivateOAuth2Session mocks base method. func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { m.ctrl.T.Helper() @@ -299,6 +313,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2) } +// LoadOneTimePassword mocks base method. +func (m *MockStorage) LoadOneTimePassword(arg0 context.Context, arg1, arg2, arg3 string) (*model.OneTimePassword, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadOneTimePassword", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*model.OneTimePassword) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadOneTimePassword indicates an expected call of LoadOneTimePassword. +func (mr *MockStorageMockRecorder) LoadOneTimePassword(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOneTimePassword", reflect.TypeOf((*MockStorage)(nil).LoadOneTimePassword), arg0, arg1, arg2, arg3) +} + // LoadPreferred2FAMethod mocks base method. func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { m.ctrl.T.Helper() @@ -521,6 +550,20 @@ func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, ar return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2) } +// RevokeOneTimePassword mocks base method. +func (m *MockStorage) RevokeOneTimePassword(arg0 context.Context, arg1 uuid.UUID, arg2 model.IP) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeOneTimePassword", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeOneTimePassword indicates an expected call of RevokeOneTimePassword. +func (mr *MockStorageMockRecorder) RevokeOneTimePassword(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).RevokeOneTimePassword), arg0, arg1, arg2) +} + // Rollback mocks base method. func (m *MockStorage) Rollback(arg0 context.Context) error { m.ctrl.T.Helper() @@ -662,6 +705,21 @@ func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2) } +// SaveOneTimePassword mocks base method. +func (m *MockStorage) SaveOneTimePassword(arg0 context.Context, arg1 model.OneTimePassword) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOneTimePassword", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveOneTimePassword indicates an expected call of SaveOneTimePassword. +func (mr *MockStorageMockRecorder) SaveOneTimePassword(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOneTimePassword", reflect.TypeOf((*MockStorage)(nil).SaveOneTimePassword), arg0, arg1) +} + // SavePreferred2FAMethod mocks base method. func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() diff --git a/internal/model/one_time_password.go b/internal/model/one_time_password.go new file mode 100644 index 000000000..364631662 --- /dev/null +++ b/internal/model/one_time_password.go @@ -0,0 +1,43 @@ +package model + +import ( + "database/sql" + "net" + "time" + + "github.com/google/uuid" +) + +const ( + OTPIntentElevateUserSession = "eus" +) + +// NewOneTimePassword returns a new OneTimePassword. +func NewOneTimePassword(publicID uuid.UUID, username, intent string, iat, exp time.Time, ip net.IP, value []byte) (otp OneTimePassword) { + return OneTimePassword{ + PublicID: publicID, + IssuedAt: iat, + ExpiresAt: exp, + Username: username, + Intent: intent, + IssuedIP: NewIP(ip), + Password: value, + } +} + +// OneTimePassword represents special one time passwords stored in the database. +type OneTimePassword struct { + ID int `db:"id"` + PublicID uuid.UUID `db:"public_id"` + Signature string `db:"signature"` + IssuedAt time.Time `db:"iat"` + IssuedIP IP `db:"issued_ip"` + ExpiresAt time.Time `db:"exp"` + Username string `db:"username"` + Intent string `db:"intent"` + Consumed sql.NullTime `db:"consumed"` + ConsumedIP NullIP `db:"consumed_ip"` + Revoked sql.NullTime `db:"revoked"` + RevokedIP NullIP `db:"revoked_ip"` + Password []byte `db:"password"` +} diff --git a/internal/server/handlers.go b/internal/server/handlers.go index fb7eaf81e..704d32a2c 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -176,7 +176,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers) middleware2FA := middlewares.NewBridgeBuilder(*config, providers). WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). - WithPostMiddlewares(middlewares.Require2FAWithAPIResponse). + WithPostMiddlewares(middlewares.Require2FA). Build() r.HEAD("/api/health", middlewareAPI(handlers.HealthGET)) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 18e506a4b..a4d3612e7 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/authelia/authelia/v4/internal/random" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -138,7 +139,9 @@ type TLSServerContext struct { func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) { serverContext = new(TLSServerContext) - providers := middlewares.Providers{} + providers := middlewares.Providers{ + Random: random.NewMathematical(), + } providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath}) if err != nil { diff --git a/internal/session/types.go b/internal/session/types.go index 5fdfb8334..124cadda4 100644 --- a/internal/session/types.go +++ b/internal/session/types.go @@ -1,6 +1,7 @@ package session import ( + "net" "time" "github.com/fasthttp/session/v2" @@ -44,6 +45,8 @@ type UserSession struct { PasswordResetUsername *string RefreshTTL time.Time + + Elevations Elevations } // TOTP holds the TOTP registration session data. @@ -68,3 +71,13 @@ type Identity struct { Email string DisplayName string } + +type Elevations struct { + User *Elevation +} + +type Elevation struct { + ID int + RemoteIP net.IP + Expires time.Time +} diff --git a/internal/storage/const.go b/internal/storage/const.go index a96fba0f0..0123b2297 100644 --- a/internal/storage/const.go +++ b/internal/storage/const.go @@ -8,6 +8,7 @@ const ( tableAuthenticationLogs = "authentication_logs" tableDuoDevices = "duo_devices" tableIdentityVerification = "identity_verification" + tableOneTimePassword = "one_time_password" tableTOTPConfigurations = "totp_configurations" tableUserOpaqueIdentifier = "user_opaque_identifier" tableUserPreferences = "user_preferences" diff --git a/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql b/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql index 1af55ed6b..52d9c5434 100644 --- a/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql +++ b/internal/storage/migrations/V0007.ConsistencyFixes.sqlite.up.sql @@ -419,7 +419,7 @@ CREATE TABLE IF NOT EXISTS oauth2_consent_preconfiguration ( revoked BOOLEAN NOT NULL DEFAULT FALSE, scopes TEXT NOT NULL, audience TEXT NULL, - CONSTRAINT "oauth2_consent_preconfiguration_subject_fkey" + CONSTRAINT oauth2_consent_preconfiguration_subject_fkey FOREIGN KEY (subject) REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT ); diff --git a/internal/storage/migrations/V0011.OneTimePassword.all.down.sql b/internal/storage/migrations/V0011.OneTimePassword.all.down.sql new file mode 100644 index 000000000..47fdc965d --- /dev/null +++ b/internal/storage/migrations/V0011.OneTimePassword.all.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS one_time_password; +DROP TABLE IF EXISTS user_elevated_session; diff --git a/internal/storage/migrations/V0011.OneTimePassword.mysql.up.sql b/internal/storage/migrations/V0011.OneTimePassword.mysql.up.sql new file mode 100644 index 000000000..8a59532a1 --- /dev/null +++ b/internal/storage/migrations/V0011.OneTimePassword.mysql.up.sql @@ -0,0 +1,30 @@ +CREATE TABLE IF NOT EXISTS one_time_password ( + id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, + public_id CHAR(36) NOT NULL, + signature VARCHAR(128) NOT NULL, + iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + issued_ip VARCHAR(39) NOT NULL, + exp TIMESTAMP NOT NULL, + username VARCHAR(100) NOT NULL, + intent VARCHAR(100) NOT NULL, + consumed TIMESTAMP NULL DEFAULT NULL, + consumed_ip VARCHAR(39) NULL DEFAULT NULL, + revoked TIMESTAMP NULL DEFAULT NULL, + revoked_ip VARCHAR(39) NULL DEFAULT NULL, + password BLOB NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci; + +CREATE UNIQUE INDEX one_time_password_signature ON one_time_password (signature); +CREATE INDEX one_time_password_lookup ON one_time_password (signature, username); + +CREATE TABLE IF NOT EXISTS user_elevated_session ( + id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_ip VARCHAR(39) NOT NULL, + method VARCHAR(10) NOT NULL, + method_id INTEGER NULL, + expires TIMESTAMP NOT NULL, + username VARCHAR(100) NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci; + +CREATE INDEX user_elevated_session_username ON user_elevated_session (username); diff --git a/internal/storage/migrations/V0011.OneTimePassword.postgres.up.sql b/internal/storage/migrations/V0011.OneTimePassword.postgres.up.sql new file mode 100644 index 000000000..8f4e50916 --- /dev/null +++ b/internal/storage/migrations/V0011.OneTimePassword.postgres.up.sql @@ -0,0 +1,30 @@ +CREATE TABLE IF NOT EXISTS one_time_password ( + id SERIAL CONSTRAINT one_time_password_pkey PRIMARY KEY, + public_id CHAR(36) NOT NULL, + signature VARCHAR(128) NOT NULL, + iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + issued_ip VARCHAR(39) NOT NULL, + exp TIMESTAMP WITH TIME ZONE NOT NULL, + username VARCHAR(100) NOT NULL, + intent VARCHAR(100) NOT NULL, + consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, + consumed_ip VARCHAR(39) NULL DEFAULT NULL, + revoked TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, + revoked_ip VARCHAR(39) NULL DEFAULT NULL, + password BYTEA NOT NULL +); + +CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username); +CREATE INDEX one_time_password_lookup ON one_time_password (signature, username); + +CREATE TABLE IF NOT EXISTS user_elevated_session ( + id SERIAL CONSTRAINT user_elevated_session_pkey PRIMARY KEY, + created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_ip VARCHAR(39) NOT NULL, + method VARCHAR(10) NOT NULL, + method_id INTEGER NULL, + expires TIMESTAMP WITH TIME ZONE NOT NULL, + username VARCHAR(100) NOT NULL +); + +CREATE INDEX user_elevated_session_username ON user_elevated_session (username); diff --git a/internal/storage/migrations/V0011.OneTimePassword.sqlite.up.sql b/internal/storage/migrations/V0011.OneTimePassword.sqlite.up.sql new file mode 100644 index 000000000..23b1266a5 --- /dev/null +++ b/internal/storage/migrations/V0011.OneTimePassword.sqlite.up.sql @@ -0,0 +1,30 @@ +CREATE TABLE IF NOT EXISTS one_time_password ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + public_id CHAR(36) NOT NULL, + signature VARCHAR(128) NOT NULL, + iat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + issued_ip VARCHAR(39) NOT NULL, + exp DATETIME NOT NULL, + username VARCHAR(100) NOT NULL, + intent VARCHAR(100) NOT NULL, + consumed DATETIME NULL DEFAULT NULL, + consumed_ip VARCHAR(39) NULL DEFAULT NULL, + revoked DATETIME NULL DEFAULT NULL, + revoked_ip VARCHAR(39) NULL DEFAULT NULL, + password BLOB NOT NULL +); + +CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username); +CREATE INDEX one_time_password_lookup ON one_time_password (signature, username); + +CREATE TABLE IF NOT EXISTS user_elevated_session ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + created DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_ip VARCHAR(39) NOT NULL, + method VARCHAR(10) NOT NULL, + method_id INTEGER NULL, + expires DATETIME NOT NULL, + username VARCHAR(100) NOT NULL +); + +CREATE INDEX user_elevated_session_username ON user_elevated_session (username); diff --git a/internal/storage/migrations_test.go b/internal/storage/migrations_test.go index d4331555a..fcaeef6af 100644 --- a/internal/storage/migrations_test.go +++ b/internal/storage/migrations_test.go @@ -9,7 +9,7 @@ import ( const ( // This is the latest schema version for the purpose of tests. - LatestVersion = 10 + LatestVersion = 11 ) func TestShouldObtainCorrectUpMigrations(t *testing.T) { diff --git a/internal/storage/provider.go b/internal/storage/provider.go index 800a258e8..0b76b7d48 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -32,6 +32,11 @@ type Provider interface { ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) + SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error) + ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error) + RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) + LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error) + SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index b820e75df..f2c458fc4 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -29,7 +29,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa driverName: driverName, config: config, errOpen: err, - log: logging.Logger(), + + keys: SQLProviderKeys{ + encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)), + }, + + log: logging.Logger(), sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), @@ -38,6 +43,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification), sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification), + sqlInsertOneTimePassword: fmt.Sprintf(queryFmtInsertOTP, tableOneTimePassword), + sqlConsumeOneTimePassword: fmt.Sprintf(queryFmtConsumeOTP, tableOneTimePassword), + sqlRevokeOneTimePassword: fmt.Sprintf(queryFmtRevokeOTP, tableOneTimePassword), + sqlSelectOneTimePassword: fmt.Sprintf(queryFmtSelectOTP, tableOneTimePassword), + sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), @@ -139,14 +149,17 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa // SQLProvider is a storage provider persisting data in a SQL database. type SQLProvider struct { - db *sqlx.DB - key [32]byte + db *sqlx.DB + key [32]byte + name string driverName string schema string config *schema.Configuration errOpen error + keys SQLProviderKeys + log *logrus.Logger // Table: authentication_logs. @@ -158,6 +171,12 @@ type SQLProvider struct { sqlConsumeIdentityVerification string sqlSelectIdentityVerification string + // Table: one_time_password. + sqlInsertOneTimePassword string + sqlConsumeOneTimePassword string + sqlRevokeOneTimePassword string + sqlSelectOneTimePassword string + // Table: totp_configurations. sqlUpsertTOTPConfig string sqlDeleteTOTPConfig string @@ -274,6 +293,12 @@ type SQLProvider struct { sqlFmtRenameTable string } +// SQLProviderKeys are the cryptography keys used by a SQLProvider. +type SQLProviderKeys struct { + encryption [32]byte + signature []byte +} + // Close the underlying database connection. func (p *SQLProvider) Close() (err error) { return p.db.Close() @@ -313,14 +338,19 @@ func (p *SQLProvider) StartupCheck() (err error) { } switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err { + case nil: + break case ErrSchemaAlreadyUpToDate: p.log.Infof("Storage schema is already up to date") - return nil - case nil: - return nil default: return fmt.Errorf("error during schema migrate: %w", err) } + + if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil { + return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err) + } + + return nil } // BeginTX begins a transaction. @@ -816,6 +846,62 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) } } +// SaveOneTimePassword saves a one time password to the database after generating the signature. +func (p *SQLProvider) SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error) { + signature = p.hmacSignature([]byte(otp.Username), []byte(otp.Intent), otp.Password) + + if otp.Password, err = p.encrypt(otp.Password); err != nil { + return "", fmt.Errorf("error encrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err) + } + + if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimePassword, + otp.PublicID, signature, otp.IssuedAt, otp.IssuedIP, otp.ExpiresAt, + otp.Username, otp.Intent, otp.Password); err != nil { + return "", fmt.Errorf("error inserting one time password for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err) + } + + return signature, nil +} + +// ConsumeOneTimePassword consumes a one time password using the signature. +func (p *SQLProvider) ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimePassword, otp.Consumed, otp.ConsumedIP, otp.Signature); err != nil { + return fmt.Errorf("error updating one time password (consume): %w", err) + } + + return nil +} + +// RevokeOneTimePassword revokes a one time password using the public ID. +func (p *SQLProvider) RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimePassword, ip, publicID); err != nil { + return fmt.Errorf("error updating one time password (revoke): %w", err) + } + + return nil +} + +// LoadOneTimePassword loads a one time password from the database given a username, intent, and password. +func (p *SQLProvider) LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error) { + otp = &model.OneTimePassword{} + + signature := p.hmacSignature([]byte(username), []byte(intent), []byte(password)) + + if err = p.db.GetContext(ctx, otp, p.sqlSelectOneTimePassword, signature, username); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + + return nil, fmt.Errorf("error selecting one time password: %w", err) + } + + if otp.Password, err = p.decrypt(otp.Password); err != nil { + return nil, fmt.Errorf("error decrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err) + } + + return otp, nil +} + // SaveTOTPConfiguration save a TOTP configuration of a given user in the database. func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { if config.Secret, err = p.encrypt(config.Secret); err != nil { diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 914c72fe7..4a263155f 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -52,6 +52,11 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification) provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification) + provider.sqlInsertOneTimePassword = provider.db.Rebind(provider.sqlInsertOneTimePassword) + provider.sqlConsumeOneTimePassword = provider.db.Rebind(provider.sqlConsumeOneTimePassword) + provider.sqlRevokeOneTimePassword = provider.db.Rebind(provider.sqlRevokeOneTimePassword) + provider.sqlSelectOneTimePassword = provider.db.Rebind(provider.sqlSelectOneTimePassword) + provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig) provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn) provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername) diff --git a/internal/storage/sql_provider_encryption.go b/internal/storage/sql_provider_encryption.go index 01b9bd848..b8ad9d6e9 100644 --- a/internal/storage/sql_provider_encryption.go +++ b/internal/storage/sql_provider_encryption.go @@ -3,7 +3,10 @@ package storage import ( "bytes" "context" + "crypto/hmac" + "crypto/rand" "crypto/sha256" + "crypto/sha512" "database/sql" "errors" "fmt" @@ -16,10 +19,10 @@ import ( // SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided // by this command to encrypt the values again and update them using a transaction. -func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) { - skey := sha256.Sum256([]byte(key)) +func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) { + key := sha256.Sum256([]byte(rawKey)) - if bytes.Equal(skey[:], p.key[:]) { + if bytes.Equal(key[:], p.keys.encryption[:]) { return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same") } @@ -33,6 +36,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) } encChangeFuncs := []EncryptionChangeKeyFunc{ + schemaEncryptionChangeKeyOneTimePassword, schemaEncryptionChangeKeyTOTP, schemaEncryptionChangeKeyWebAuthn, } @@ -47,8 +51,10 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session)) } + encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption) + for _, encChangeFunc := range encChangeFuncs { - if err = encChangeFunc(ctx, p, tx, skey); err != nil { + if err = encChangeFunc(ctx, p, tx, key); err != nil { if rerr := tx.Rollback(); rerr != nil { return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err) } @@ -57,14 +63,6 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) } } - if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil { - if rerr := tx.Rollback(); rerr != nil { - return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err) - } - - return fmt.Errorf("rollback due to error: %w", err) - } - return tx.Commit() } @@ -89,6 +87,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool if verbose { encCheckFuncs := []EncryptionCheckKeyFunc{ + schemaEncryptionCheckKeyOneTimePassword, schemaEncryptionCheckKeyTOTP, schemaEncryptionCheckKeyWebAuthn, } @@ -103,6 +102,8 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session)) } + encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyEncryption) + for _, encCheckFunc := range encCheckFuncs { table, tableResult := encCheckFunc(ctx, p) @@ -113,6 +114,46 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool return result, nil } +func schemaEncryptionChangeKeyOneTimePassword(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { + var count int + + if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimePassword)); err != nil { + return err + } + + if count == 0 { + return nil + } + + configs := make([]encOneTimePassword, 0, count) + + if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil + } + + return fmt.Errorf("error selecting one time passwords: %w", err) + } + + query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTPEncryptedData, tableOneTimePassword)) + + for _, c := range configs { + if c.OTP, err = provider.decrypt(c.OTP); err != nil { + return fmt.Errorf("error decrypting one time password with id '%d': %w", c.ID, err) + } + + if c.OTP, err = utils.Encrypt(c.OTP, &key); err != nil { + return fmt.Errorf("error encrypting one time password with id '%d': %w", c.ID, err) + } + + if _, err = tx.ExecContext(ctx, query, c.OTP, c.ID); err != nil { + return fmt.Errorf("error updating one time password with id '%d': %w", c.ID, err) + } + } + + return nil +} + func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { var count int @@ -134,7 +175,7 @@ func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, t return fmt.Errorf("error selecting TOTP configurations: %w", err) } - query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations)) + query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations)) for _, c := range configs { if c.Secret, err = provider.decrypt(c.Secret); err != nil { @@ -231,6 +272,77 @@ func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) } } +func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { + var count int + + if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil { + return err + } + + if count == 0 { + return nil + } + + configs := make([]encEncryption, 0, count) + + if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil + } + + return fmt.Errorf("error selecting encyption value: %w", err) + } + + query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption)) + + for _, c := range configs { + if c.Value, err = provider.decrypt(c.Value); err != nil { + return fmt.Errorf("error decrypting encyption value with id '%d': %w", c.ID, err) + } + + if c.Value, err = utils.Encrypt(c.Value, &key); err != nil { + return fmt.Errorf("error encrypting encyption value with id '%d': %w", c.ID, err) + } + + if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil { + return fmt.Errorf("error updating encyption value with id '%d': %w", c.ID, err) + } + } + + return nil +} + +func schemaEncryptionCheckKeyOneTimePassword(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { + var ( + rows *sqlx.Rows + err error + ) + + if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil { + return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time passwords: %w", err)} + } + + var config encOneTimePassword + + for rows.Next() { + result.Total++ + + if err = rows.StructScan(&config); err != nil { + _ = rows.Close() + + return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time password to struct: %w", err)} + } + + if _, err = provider.decrypt(config.OTP); err != nil { + result.Invalid++ + } + } + + _ = rows.Close() + + return tableOneTimePassword, result +} + func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { var ( rows *sqlx.Rows @@ -326,12 +438,77 @@ func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType) } } +func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { + var ( + rows *sqlx.Rows + err error + ) + + if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil { + return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)} + } + + var config encEncryption + + for rows.Next() { + result.Total++ + + if err = rows.StructScan(&config); err != nil { + _ = rows.Close() + + return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)} + } + + if _, err = provider.decrypt(config.Value); err != nil { + result.Invalid++ + } + } + + _ = rows.Close() + + return tableEncryption, result +} + func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { - return utils.Encrypt(clearText, &p.key) + return utils.Encrypt(clearText, &p.keys.encryption) } func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) { - return utils.Decrypt(cipherText, &p.key) + return utils.Decrypt(cipherText, &p.keys.encryption) +} + +func (p *SQLProvider) hmacSignature(values ...[]byte) string { + h := hmac.New(sha512.New, p.keys.signature) + + for i := 0; i < len(values); i++ { + h.Write(values[i]) + } + + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func (p *SQLProvider) getKeySigHMAC(ctx context.Context) (key []byte, err error) { + if key, err = p.getEncryptionValue(ctx, "hmac_signature_key"); err != nil { + if errors.Is(err, sql.ErrNoRows) { + key = make([]byte, sha512.BlockSize) + + _, err = rand.Read(key) + + if err != nil { + return nil, fmt.Errorf("failed to generate hmac key: %w", err) + } + + if err = p.setEncryptionValue(ctx, "hmac_signature_key", key); err != nil { + return nil, err + } + + return key, nil + } + + return nil, err + } + + return key, nil } func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) { @@ -345,6 +522,18 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu return p.decrypt(encryptedValue) } +func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) { + if value, err = p.encrypt(value); err != nil { + return err + } + + if _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, value); err != nil { + return err + } + + return nil +} + func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) { valueClearText, err := uuid.NewRandom() if err != nil { diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index 941540904..59e945809 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -71,6 +71,36 @@ const ( WHERE jti = ?;` ) +const ( + queryFmtSelectOTP = ` + SELECT id, public_id, signature, iat, issued_ip, exp, username intent, consumed, consumed_ip, revoked, revoked_ip, password + FROM %s + WHERE signature = ? AND username = ?;` + + queryFmtInsertOTP = ` + INSERT INTO %s (public_id, signature, iat, issued_ip, exp, username, intent, password) + VALUES (?, ?, ?, ?, ?, ?, ?, ?);` + + queryFmtConsumeOTP = ` + UPDATE %s + SET consumed = ?, consumed_ip = ? + WHERE signature = ?;` + + queryFmtRevokeOTP = ` + UPDATE %s + SET revoked = CURRENT_TIMESTAMP, revoked_ip = ? + WHERE public_id = ?;` + + queryFmtSelectOTPEncryptedData = ` + SELECT id, password + FROM %s;` + + queryFmtUpdateOTPEncryptedData = ` + UPDATE %s + SET password = ? + WHERE id = ?;` +) + const ( queryFmtSelectTOTPConfiguration = ` SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret @@ -87,8 +117,7 @@ const ( SELECT id, secret FROM %s;` - //nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials. - queryFmtUpdateTOTPConfigurationSecret = ` + queryFmtUpdateTOTPConfigurationEncryptedData = ` UPDATE %s SET secret = ? WHERE id = ?;` @@ -241,6 +270,14 @@ const ( VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET value = $2;` + + queryFmtSelectEncryptionEncryptedData = ` + SELECT id, value + FROM %s;` + + queryFmtUpdateEncryptionEncryptedData = ` + UPDATE %s + SET value = ?` ) const ( diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index c940e6b34..cc6f6825e 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -251,7 +251,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio if migration.Version == 1 && migration.Up { // Add the schema encryption value if upgrading to v1. - if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil { + if err = p.setNewEncryptionCheckValue(ctx, conn, &p.keys.encryption); err != nil { return err } } diff --git a/internal/storage/types.go b/internal/storage/types.go index 7537c7695..7e6e72d1c 100644 --- a/internal/storage/types.go +++ b/internal/storage/types.go @@ -32,14 +32,24 @@ type encOAuth2Session struct { Session []byte `db:"session_data"` } +type encOneTimePassword struct { + ID int `db:"id"` + OTP []byte `db:"otp"` +} + +type encTOTPConfiguration struct { + ID int `db:"id"` + Secret []byte `db:"secret"` +} + type encWebAuthnDevice struct { ID int `db:"id"` PublicKey []byte `db:"public_key"` } -type encTOTPConfiguration struct { - ID int `db:"id" json:"-"` - Secret []byte `db:"secret" json:"-"` +type encEncryption struct { + ID int `db:"id"` + Value []byte `db:"value"` } // EncryptionValidationResult contains information about the success of a schema encryption validation.