feat(web): one time password verification

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
feat-otp-verification
James Elliott 2023-04-09 16:54:58 +10:00
parent 91083f0052
commit fe5e07d868
No known key found for this signature in database
GPG Key ID: 0F1C4A096E857E49
26 changed files with 1147 additions and 75 deletions

View File

@ -39,3 +39,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests | | 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |
| 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table | | 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table |
| 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes | | 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes |
| 11 | 4.38.0 | One-Time Password for Identity Verification via Email Changes |

View File

@ -24,4 +24,5 @@ type Configuration struct {
WebAuthn WebAuthnConfiguration `koanf:"webauthn"` WebAuthn WebAuthnConfiguration `koanf:"webauthn"`
PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"` PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"`
PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"` PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"`
IdentityValidation IdentityValidation `koanf:"identity_validation"`
} }

View File

@ -0,0 +1,21 @@
package schema
import (
"time"
)
type IdentityValidation struct {
ResetPassword ResetPasswordIdentityValidation `koanf:"reset_password"`
CredentialRegistration CredentialRegistrationIdentityValidation `koanf:"credential_registration"`
}
type ResetPasswordIdentityValidation struct {
EmailExpiration time.Duration `koanf:"email_expiration"`
}
type CredentialRegistrationIdentityValidation struct {
EmailExpiration time.Duration `koanf:"email_expiration"`
ElevationExpiration time.Duration `koanf:"elevation_expiration"`
OTPCharacters int `koanf:"otp_characters"`
Skip2FA bool `koanf:"skip_2fa"`
}

View File

@ -446,7 +446,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
s.mock.Ctx.Request.SetBodyString(`{ s.mock.Ctx.Request.SetBodyString(`{
"username": "test", "username": "test",
"password": "hello", "password": "hello",
"requestMethod": "GET", "requestMethod": fasthttp.MethodGet,
"keepMeLoggedIn": false "keepMeLoggedIn": false
}`) }`)

View File

@ -0,0 +1,235 @@
package middlewares
import (
"fmt"
"time"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/session"
)
// OTPEscalationProtectedEndpointConfig represents how the Escalation middleware behaves.
type OTPEscalationProtectedEndpointConfig struct {
Characters int
EmailValidityDuration time.Duration
EscalationValidityDuration time.Duration
Skip2FA bool
}
type RequiredLevelProtectedEndpointConfig struct {
Level authentication.Level
}
type ProtectedEndpointConfig struct {
OTPEscalation *OTPEscalationProtectedEndpointConfig
RequiredLevel *RequiredLevelProtectedEndpointConfig
}
func NewProtectedEndpoint(config *ProtectedEndpointConfig) AutheliaMiddleware {
var handlers []ProtectedEndpointHandler
if config.RequiredLevel != nil {
handlers = append(handlers, &RequiredLevelProtectedEndpointHandler{level: config.RequiredLevel.Level})
}
if config.OTPEscalation != nil {
handlers = append(handlers, &OTPEscalationProtectedEndpointHandler{config: config.OTPEscalation})
}
return ProtectedEndpoint(handlers...)
}
func ProtectedEndpoint(handlers ...ProtectedEndpointHandler) AutheliaMiddleware {
n := len(handlers)
return func(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
session, err := ctx.GetSession()
if err != nil || session.IsAnonymous() {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
return
}
var failed, failedAuthentication, failedElevation bool
for i := 0; i < n; i++ {
if handlers[i].Check(ctx, &session) {
continue
}
failed = true
if handlers[i].IsAuthentication() {
failedAuthentication = true
}
if handlers[i].IsElevation() {
failedElevation = true
}
handlers[i].Failure(ctx, &session)
}
if failed {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, fasthttp.StatusMessage(fasthttp.StatusForbidden), failedAuthentication, failedElevation)
return
}
next(ctx)
}
}
}
type ProtectedEndpointHandler interface {
Name() string
Check(ctx *AutheliaCtx, s *session.UserSession) (success bool)
Failure(ctx *AutheliaCtx, s *session.UserSession)
IsAuthentication() bool
IsElevation() bool
}
func NewRequiredLevelProtectedEndpointHandler(level authentication.Level, statusCode int) *RequiredLevelProtectedEndpointHandler {
handler := &RequiredLevelProtectedEndpointHandler{
level: level,
statusCode: statusCode,
}
if handler.statusCode == 0 {
handler.statusCode = fasthttp.StatusForbidden
}
if handler.level == 0 {
handler.level = authentication.OneFactor
}
return handler
}
type RequiredLevelProtectedEndpointHandler struct {
level authentication.Level
statusCode int
}
func (h *RequiredLevelProtectedEndpointHandler) Name() string {
return fmt.Sprintf("required_level(%s)", h.level)
}
func (h *RequiredLevelProtectedEndpointHandler) IsAuthentication() bool {
return true
}
func (h *RequiredLevelProtectedEndpointHandler) IsElevation() bool {
return false
}
func (h *RequiredLevelProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
return s.AuthenticationLevel >= h.level
}
func (h *RequiredLevelProtectedEndpointHandler) Failure(_ *AutheliaCtx, _ *session.UserSession) {
}
func NewOTPEscalationProtectedEndpointHandler(config OTPEscalationProtectedEndpointConfig) *OTPEscalationProtectedEndpointHandler {
return &OTPEscalationProtectedEndpointHandler{
config: &config,
}
}
type OTPEscalationProtectedEndpointHandler struct {
config *OTPEscalationProtectedEndpointConfig
}
func (h *OTPEscalationProtectedEndpointHandler) Name() string {
return "one_time_password"
}
func (h *OTPEscalationProtectedEndpointHandler) IsAuthentication() bool {
return false
}
func (h *OTPEscalationProtectedEndpointHandler) IsElevation() bool {
return true
}
func (h *OTPEscalationProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
if h.config.Skip2FA && s.AuthenticationLevel >= authentication.TwoFactor {
ctx.Logger.
WithField("username", s.Username).
Warning("User elevated session check has skipped due to 2FA")
return true
}
if s.Elevations.User == nil {
ctx.Logger.
WithField("username", s.Username).
Warning("User elevated session has not been created")
return false
}
if s.Elevations.User.Expires.Before(ctx.Clock.Now()) {
ctx.Logger.
WithField("username", s.Username).
WithField("expires", s.Elevations.User.Expires).
Debug("User elevated session IP did not match the request")
return false
}
if !ctx.RemoteIP().Equal(s.Elevations.User.RemoteIP) {
ctx.Logger.
WithField("username", s.Username).
WithField("elevation_ip", s.Elevations.User.RemoteIP).
Warning("User elevated session IP did not match the request")
return false
}
return true
}
func (h *OTPEscalationProtectedEndpointHandler) Failure(ctx *AutheliaCtx, s *session.UserSession) {
if s.Elevations.User != nil {
// If we make it here we should destroy the elevation data.
s.Elevations.User = nil
if err := ctx.SaveSession(*s); err != nil {
ctx.Logger.WithError(err).Error("Error session after user elevated session failure")
}
}
}
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
func Require1FA(next RequestHandler) RequestHandler {
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.OneFactor, fasthttp.StatusForbidden))
return handler(next)
}
// Require2FA requires the user to have authenticated with two-factor authentication.
func Require2FA(next RequestHandler) RequestHandler {
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.TwoFactor, fasthttp.StatusForbidden))
return handler(next)
}
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
session, err := ctx.GetSession()
if err != nil || session.AuthenticationLevel < authentication.TwoFactor {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
return
}
next(ctx)
}
}

View File

@ -0,0 +1,315 @@
package middlewares_test
import (
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/session"
)
func handleSetStatus(code int) middlewares.RequestHandler {
return func(ctx *middlewares.AutheliaCtx) {
if err := ctx.ReplyJSON(middlewares.ErrorResponse{Status: "OK", Message: "Endpoint Response"}, 0); err != nil {
ctx.Logger.Error(err)
}
ctx.SetStatusCode(code)
ctx.Response.Header.Set("X-Testing-Success", "yes")
}
}
func TestProtectedEndpointRequiredLevel(t *testing.T) {
testCases := []struct {
name string
level authentication.Level
have session.UserSession
expected, status int
expectedAuth, expectedElevate bool
}{
{
name: "1FAWithAuthenticatedUser2FAShould200OK",
level: authentication.OneFactor,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
},
},
{
name: "1FAWithAuthenticatedUser1FAShould200OK",
level: authentication.OneFactor,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
},
},
{
name: "1FAWithAuthenticatedUser2FAShould301Found",
level: authentication.OneFactor,
expected: fasthttp.StatusFound,
status: fasthttp.StatusFound,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
},
},
{
name: "1FAWithAuthenticatedUser1FAShould301Found",
level: authentication.OneFactor,
expected: fasthttp.StatusFound,
status: fasthttp.StatusFound,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
},
},
{
name: "1FAWithNotAuthenticatedUserShould401Unauthenticated",
level: authentication.OneFactor,
expected: fasthttp.StatusUnauthorized,
status: fasthttp.StatusOK,
have: session.UserSession{
AuthenticationLevel: authentication.NotAuthenticated,
},
},
{
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
level: authentication.OneFactor,
expected: fasthttp.StatusUnauthorized,
status: fasthttp.StatusFound,
have: session.UserSession{
AuthenticationLevel: authentication.NotAuthenticated,
},
},
{
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
level: authentication.TwoFactor,
expected: fasthttp.StatusForbidden,
expectedAuth: true,
status: fasthttp.StatusFound,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
err := mock.Ctx.SaveSession(tc.have)
require.NoError(t, err)
var h middlewares.RequestHandler
switch tc.level {
case authentication.OneFactor:
h = middlewares.Require1FA(handleSetStatus(tc.status))
default:
h = middlewares.Require2FA(handleSetStatus(tc.status))
}
h(mock.Ctx)
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
if tc.expected == tc.status {
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
} else {
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":%t,"elevation":%t}`, fasthttp.StatusMessage(tc.expected), tc.expectedAuth, tc.expectedElevate), string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
}
})
}
}
func TestProtectedEndpointOTP(t *testing.T) {
testCases := []struct {
name string
characters int
emailexp, sessionexp time.Duration
skip2fa bool
have session.UserSession
ip net.IP
time time.Time
expected, status int
}{
{
name: "ReturnUnauthorizedForAnonymous",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusUnauthorized,
status: fasthttp.StatusFound,
have: session.UserSession{
AuthenticationLevel: authentication.NotAuthenticated,
},
},
{
name: "Return200OKWhen2FASkipAndUserIs2FAd",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: true,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
},
},
{
name: "HandleEscalationEmailWhen2FASkipAndUserIs1FAd",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: true,
expected: fasthttp.StatusForbidden,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
Elevations: session.Elevations{User: nil},
},
},
{
name: "HandleEscalationEmailWhenUserIs2FAd",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusForbidden,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
Elevations: session.Elevations{User: nil},
},
},
{
name: "Return200OKWhenUserIsEscalated",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
ip: net.ParseIP("192.168.0.1"),
time: time.Unix(1671322337, 0),
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
Elevations: session.Elevations{
User: &session.Elevation{
RemoteIP: net.ParseIP("192.168.0.1"),
Expires: time.Unix(1671322347, 0),
},
},
},
},
{
name: "Return403ForbiddenWhenUserIsEscalatedButInvalidIP",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusForbidden,
status: fasthttp.StatusOK,
ip: net.ParseIP("192.168.0.2"),
time: time.Unix(1671322337, 0),
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
Elevations: session.Elevations{
User: &session.Elevation{
RemoteIP: net.ParseIP("192.168.0.1"),
Expires: time.Unix(1671322347, 0),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, tc.ip.String())
err := mock.Ctx.SaveSession(tc.have)
require.NoError(t, err)
if !tc.time.IsZero() {
mock.Clock.Set(tc.time)
mock.Ctx.Clock = &mock.Clock
}
h := middlewares.ProtectedEndpoint(middlewares.NewOTPEscalationProtectedEndpointHandler(middlewares.OTPEscalationProtectedEndpointConfig{
Characters: tc.characters,
EmailValidityDuration: tc.emailexp,
EscalationValidityDuration: tc.sessionexp,
Skip2FA: tc.skip2fa,
}))(handleSetStatus(tc.status))
h(mock.Ctx)
switch {
case tc.have.IsAnonymous():
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":false,"elevation":false}`, fasthttp.StatusMessage(tc.expected)), string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
case tc.skip2fa && tc.have.AuthenticationLevel == authentication.TwoFactor:
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
case tc.have.Elevations.User == nil || mock.Ctx.Clock.Now().After(tc.have.Elevations.User.Expires) || !tc.ip.Equal(tc.have.Elevations.User.RemoteIP):
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, `{"status":"KO","message":"Forbidden","authentication":false,"elevation":true}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
default:
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
}
})
}
}

View File

@ -1,43 +0,0 @@
package middlewares
import (
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
)
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
func Require1FA(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.OneFactor {
ctx.ReplyForbidden()
return
}
next(ctx)
}
}
// Require2FA requires the user to have authenticated with two-factor authentication.
func Require2FA(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
ctx.ReplyForbidden()
return
}
next(ctx)
}
}
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
return
}
next(ctx)
}
}

View File

@ -110,6 +110,20 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
} }
// ConsumeOneTimePassword mocks base method.
func (m *MockStorage) ConsumeOneTimePassword(arg0 context.Context, arg1 *model.OneTimePassword) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConsumeOneTimePassword", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ConsumeOneTimePassword indicates an expected call of ConsumeOneTimePassword.
func (mr *MockStorageMockRecorder) ConsumeOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).ConsumeOneTimePassword), arg0, arg1)
}
// DeactivateOAuth2Session mocks base method. // DeactivateOAuth2Session mocks base method.
func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -299,6 +313,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
} }
// LoadOneTimePassword mocks base method.
func (m *MockStorage) LoadOneTimePassword(arg0 context.Context, arg1, arg2, arg3 string) (*model.OneTimePassword, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOneTimePassword", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*model.OneTimePassword)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOneTimePassword indicates an expected call of LoadOneTimePassword.
func (mr *MockStorageMockRecorder) LoadOneTimePassword(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOneTimePassword", reflect.TypeOf((*MockStorage)(nil).LoadOneTimePassword), arg0, arg1, arg2, arg3)
}
// LoadPreferred2FAMethod mocks base method. // LoadPreferred2FAMethod mocks base method.
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -521,6 +550,20 @@ func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, ar
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
} }
// RevokeOneTimePassword mocks base method.
func (m *MockStorage) RevokeOneTimePassword(arg0 context.Context, arg1 uuid.UUID, arg2 model.IP) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeOneTimePassword", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeOneTimePassword indicates an expected call of RevokeOneTimePassword.
func (mr *MockStorageMockRecorder) RevokeOneTimePassword(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).RevokeOneTimePassword), arg0, arg1, arg2)
}
// Rollback mocks base method. // Rollback mocks base method.
func (m *MockStorage) Rollback(arg0 context.Context) error { func (m *MockStorage) Rollback(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -662,6 +705,21 @@ func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
} }
// SaveOneTimePassword mocks base method.
func (m *MockStorage) SaveOneTimePassword(arg0 context.Context, arg1 model.OneTimePassword) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOneTimePassword", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SaveOneTimePassword indicates an expected call of SaveOneTimePassword.
func (mr *MockStorageMockRecorder) SaveOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOneTimePassword", reflect.TypeOf((*MockStorage)(nil).SaveOneTimePassword), arg0, arg1)
}
// SavePreferred2FAMethod mocks base method. // SavePreferred2FAMethod mocks base method.
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -0,0 +1,43 @@
package model
import (
"database/sql"
"net"
"time"
"github.com/google/uuid"
)
const (
OTPIntentElevateUserSession = "eus"
)
// NewOneTimePassword returns a new OneTimePassword.
func NewOneTimePassword(publicID uuid.UUID, username, intent string, iat, exp time.Time, ip net.IP, value []byte) (otp OneTimePassword) {
return OneTimePassword{
PublicID: publicID,
IssuedAt: iat,
ExpiresAt: exp,
Username: username,
Intent: intent,
IssuedIP: NewIP(ip),
Password: value,
}
}
// OneTimePassword represents special one time passwords stored in the database.
type OneTimePassword struct {
ID int `db:"id"`
PublicID uuid.UUID `db:"public_id"`
Signature string `db:"signature"`
IssuedAt time.Time `db:"iat"`
IssuedIP IP `db:"issued_ip"`
ExpiresAt time.Time `db:"exp"`
Username string `db:"username"`
Intent string `db:"intent"`
Consumed sql.NullTime `db:"consumed"`
ConsumedIP NullIP `db:"consumed_ip"`
Revoked sql.NullTime `db:"revoked"`
RevokedIP NullIP `db:"revoked_ip"`
Password []byte `db:"password"`
}

View File

@ -176,7 +176,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
middleware2FA := middlewares.NewBridgeBuilder(*config, providers). middleware2FA := middlewares.NewBridgeBuilder(*config, providers).
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
WithPostMiddlewares(middlewares.Require2FAWithAPIResponse). WithPostMiddlewares(middlewares.Require2FA).
Build() Build()
r.HEAD("/api/health", middlewareAPI(handlers.HealthGET)) r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))

View File

@ -14,6 +14,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/authelia/authelia/v4/internal/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
@ -138,7 +139,9 @@ type TLSServerContext struct {
func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) { func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) {
serverContext = new(TLSServerContext) serverContext = new(TLSServerContext)
providers := middlewares.Providers{} providers := middlewares.Providers{
Random: random.NewMathematical(),
}
providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath}) providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath})
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package session package session
import ( import (
"net"
"time" "time"
"github.com/fasthttp/session/v2" "github.com/fasthttp/session/v2"
@ -44,6 +45,8 @@ type UserSession struct {
PasswordResetUsername *string PasswordResetUsername *string
RefreshTTL time.Time RefreshTTL time.Time
Elevations Elevations
} }
// TOTP holds the TOTP registration session data. // TOTP holds the TOTP registration session data.
@ -68,3 +71,13 @@ type Identity struct {
Email string Email string
DisplayName string DisplayName string
} }
type Elevations struct {
User *Elevation
}
type Elevation struct {
ID int
RemoteIP net.IP
Expires time.Time
}

View File

@ -8,6 +8,7 @@ const (
tableAuthenticationLogs = "authentication_logs" tableAuthenticationLogs = "authentication_logs"
tableDuoDevices = "duo_devices" tableDuoDevices = "duo_devices"
tableIdentityVerification = "identity_verification" tableIdentityVerification = "identity_verification"
tableOneTimePassword = "one_time_password"
tableTOTPConfigurations = "totp_configurations" tableTOTPConfigurations = "totp_configurations"
tableUserOpaqueIdentifier = "user_opaque_identifier" tableUserOpaqueIdentifier = "user_opaque_identifier"
tableUserPreferences = "user_preferences" tableUserPreferences = "user_preferences"

View File

@ -419,7 +419,7 @@ CREATE TABLE IF NOT EXISTS oauth2_consent_preconfiguration (
revoked BOOLEAN NOT NULL DEFAULT FALSE, revoked BOOLEAN NOT NULL DEFAULT FALSE,
scopes TEXT NOT NULL, scopes TEXT NOT NULL,
audience TEXT NULL, audience TEXT NULL,
CONSTRAINT "oauth2_consent_preconfiguration_subject_fkey" CONSTRAINT oauth2_consent_preconfiguration_subject_fkey
FOREIGN KEY (subject) FOREIGN KEY (subject)
REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT
); );

View File

@ -0,0 +1,2 @@
DROP TABLE IF EXISTS one_time_password;
DROP TABLE IF EXISTS user_elevated_session;

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS one_time_password (
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
public_id CHAR(36) NOT NULL,
signature VARCHAR(128) NOT NULL,
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP NOT NULL,
username VARCHAR(100) NOT NULL,
intent VARCHAR(100) NOT NULL,
consumed TIMESTAMP NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
revoked TIMESTAMP NULL DEFAULT NULL,
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
password BLOB NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
CREATE UNIQUE INDEX one_time_password_signature ON one_time_password (signature);
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
CREATE TABLE IF NOT EXISTS user_elevated_session (
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_ip VARCHAR(39) NOT NULL,
method VARCHAR(10) NOT NULL,
method_id INTEGER NULL,
expires TIMESTAMP NOT NULL,
username VARCHAR(100) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS one_time_password (
id SERIAL CONSTRAINT one_time_password_pkey PRIMARY KEY,
public_id CHAR(36) NOT NULL,
signature VARCHAR(128) NOT NULL,
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP WITH TIME ZONE NOT NULL,
username VARCHAR(100) NOT NULL,
intent VARCHAR(100) NOT NULL,
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
revoked TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
password BYTEA NOT NULL
);
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
CREATE TABLE IF NOT EXISTS user_elevated_session (
id SERIAL CONSTRAINT user_elevated_session_pkey PRIMARY KEY,
created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_ip VARCHAR(39) NOT NULL,
method VARCHAR(10) NOT NULL,
method_id INTEGER NULL,
expires TIMESTAMP WITH TIME ZONE NOT NULL,
username VARCHAR(100) NOT NULL
);
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS one_time_password (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
public_id CHAR(36) NOT NULL,
signature VARCHAR(128) NOT NULL,
iat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp DATETIME NOT NULL,
username VARCHAR(100) NOT NULL,
intent VARCHAR(100) NOT NULL,
consumed DATETIME NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
revoked DATETIME NULL DEFAULT NULL,
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
password BLOB NOT NULL
);
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
CREATE TABLE IF NOT EXISTS user_elevated_session (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
created DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_ip VARCHAR(39) NOT NULL,
method VARCHAR(10) NOT NULL,
method_id INTEGER NULL,
expires DATETIME NOT NULL,
username VARCHAR(100) NOT NULL
);
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);

View File

@ -9,7 +9,7 @@ import (
const ( const (
// This is the latest schema version for the purpose of tests. // This is the latest schema version for the purpose of tests.
LatestVersion = 10 LatestVersion = 11
) )
func TestShouldObtainCorrectUpMigrations(t *testing.T) { func TestShouldObtainCorrectUpMigrations(t *testing.T) {

View File

@ -32,6 +32,11 @@ type Provider interface {
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error)
ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error)
RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error)
LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error)
SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error)
UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error)
DeleteTOTPConfiguration(ctx context.Context, username string) (err error) DeleteTOTPConfiguration(ctx context.Context, username string) (err error)

View File

@ -29,6 +29,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
driverName: driverName, driverName: driverName,
config: config, config: config,
errOpen: err, errOpen: err,
keys: SQLProviderKeys{
encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
},
log: logging.Logger(), log: logging.Logger(),
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
@ -38,6 +43,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification), sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification), sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
sqlInsertOneTimePassword: fmt.Sprintf(queryFmtInsertOTP, tableOneTimePassword),
sqlConsumeOneTimePassword: fmt.Sprintf(queryFmtConsumeOTP, tableOneTimePassword),
sqlRevokeOneTimePassword: fmt.Sprintf(queryFmtRevokeOTP, tableOneTimePassword),
sqlSelectOneTimePassword: fmt.Sprintf(queryFmtSelectOTP, tableOneTimePassword),
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
@ -141,12 +151,15 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
type SQLProvider struct { type SQLProvider struct {
db *sqlx.DB db *sqlx.DB
key [32]byte key [32]byte
name string name string
driverName string driverName string
schema string schema string
config *schema.Configuration config *schema.Configuration
errOpen error errOpen error
keys SQLProviderKeys
log *logrus.Logger log *logrus.Logger
// Table: authentication_logs. // Table: authentication_logs.
@ -158,6 +171,12 @@ type SQLProvider struct {
sqlConsumeIdentityVerification string sqlConsumeIdentityVerification string
sqlSelectIdentityVerification string sqlSelectIdentityVerification string
// Table: one_time_password.
sqlInsertOneTimePassword string
sqlConsumeOneTimePassword string
sqlRevokeOneTimePassword string
sqlSelectOneTimePassword string
// Table: totp_configurations. // Table: totp_configurations.
sqlUpsertTOTPConfig string sqlUpsertTOTPConfig string
sqlDeleteTOTPConfig string sqlDeleteTOTPConfig string
@ -274,6 +293,12 @@ type SQLProvider struct {
sqlFmtRenameTable string sqlFmtRenameTable string
} }
// SQLProviderKeys are the cryptography keys used by a SQLProvider.
type SQLProviderKeys struct {
encryption [32]byte
signature []byte
}
// Close the underlying database connection. // Close the underlying database connection.
func (p *SQLProvider) Close() (err error) { func (p *SQLProvider) Close() (err error) {
return p.db.Close() return p.db.Close()
@ -313,14 +338,19 @@ func (p *SQLProvider) StartupCheck() (err error) {
} }
switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err { switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
case nil:
break
case ErrSchemaAlreadyUpToDate: case ErrSchemaAlreadyUpToDate:
p.log.Infof("Storage schema is already up to date") p.log.Infof("Storage schema is already up to date")
return nil
case nil:
return nil
default: default:
return fmt.Errorf("error during schema migrate: %w", err) return fmt.Errorf("error during schema migrate: %w", err)
} }
if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil {
return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err)
}
return nil
} }
// BeginTX begins a transaction. // BeginTX begins a transaction.
@ -816,6 +846,62 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
} }
} }
// SaveOneTimePassword saves a one time password to the database after generating the signature.
func (p *SQLProvider) SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error) {
signature = p.hmacSignature([]byte(otp.Username), []byte(otp.Intent), otp.Password)
if otp.Password, err = p.encrypt(otp.Password); err != nil {
return "", fmt.Errorf("error encrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
}
if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimePassword,
otp.PublicID, signature, otp.IssuedAt, otp.IssuedIP, otp.ExpiresAt,
otp.Username, otp.Intent, otp.Password); err != nil {
return "", fmt.Errorf("error inserting one time password for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
}
return signature, nil
}
// ConsumeOneTimePassword consumes a one time password using the signature.
func (p *SQLProvider) ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimePassword, otp.Consumed, otp.ConsumedIP, otp.Signature); err != nil {
return fmt.Errorf("error updating one time password (consume): %w", err)
}
return nil
}
// RevokeOneTimePassword revokes a one time password using the public ID.
func (p *SQLProvider) RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimePassword, ip, publicID); err != nil {
return fmt.Errorf("error updating one time password (revoke): %w", err)
}
return nil
}
// LoadOneTimePassword loads a one time password from the database given a username, intent, and password.
func (p *SQLProvider) LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error) {
otp = &model.OneTimePassword{}
signature := p.hmacSignature([]byte(username), []byte(intent), []byte(password))
if err = p.db.GetContext(ctx, otp, p.sqlSelectOneTimePassword, signature, username); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("error selecting one time password: %w", err)
}
if otp.Password, err = p.decrypt(otp.Password); err != nil {
return nil, fmt.Errorf("error decrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
}
return otp, nil
}
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database. // SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
if config.Secret, err = p.encrypt(config.Secret); err != nil { if config.Secret, err = p.encrypt(config.Secret); err != nil {

View File

@ -52,6 +52,11 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification) provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification) provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
provider.sqlInsertOneTimePassword = provider.db.Rebind(provider.sqlInsertOneTimePassword)
provider.sqlConsumeOneTimePassword = provider.db.Rebind(provider.sqlConsumeOneTimePassword)
provider.sqlRevokeOneTimePassword = provider.db.Rebind(provider.sqlRevokeOneTimePassword)
provider.sqlSelectOneTimePassword = provider.db.Rebind(provider.sqlSelectOneTimePassword)
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig) provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn) provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername) provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)

View File

@ -3,7 +3,10 @@ package storage
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"crypto/sha512"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
@ -16,10 +19,10 @@ import (
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided // SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
// by this command to encrypt the values again and update them using a transaction. // by this command to encrypt the values again and update them using a transaction.
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) { func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) {
skey := sha256.Sum256([]byte(key)) key := sha256.Sum256([]byte(rawKey))
if bytes.Equal(skey[:], p.key[:]) { if bytes.Equal(key[:], p.keys.encryption[:]) {
return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same") return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
} }
@ -33,6 +36,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
} }
encChangeFuncs := []EncryptionChangeKeyFunc{ encChangeFuncs := []EncryptionChangeKeyFunc{
schemaEncryptionChangeKeyOneTimePassword,
schemaEncryptionChangeKeyTOTP, schemaEncryptionChangeKeyTOTP,
schemaEncryptionChangeKeyWebAuthn, schemaEncryptionChangeKeyWebAuthn,
} }
@ -47,8 +51,10 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session)) encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
} }
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption)
for _, encChangeFunc := range encChangeFuncs { for _, encChangeFunc := range encChangeFuncs {
if err = encChangeFunc(ctx, p, tx, skey); err != nil { if err = encChangeFunc(ctx, p, tx, key); err != nil {
if rerr := tx.Rollback(); rerr != nil { if rerr := tx.Rollback(); rerr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err) return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
} }
@ -57,14 +63,6 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
} }
} }
if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil {
if rerr := tx.Rollback(); rerr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
}
return tx.Commit() return tx.Commit()
} }
@ -89,6 +87,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
if verbose { if verbose {
encCheckFuncs := []EncryptionCheckKeyFunc{ encCheckFuncs := []EncryptionCheckKeyFunc{
schemaEncryptionCheckKeyOneTimePassword,
schemaEncryptionCheckKeyTOTP, schemaEncryptionCheckKeyTOTP,
schemaEncryptionCheckKeyWebAuthn, schemaEncryptionCheckKeyWebAuthn,
} }
@ -103,6 +102,8 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session)) encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
} }
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyEncryption)
for _, encCheckFunc := range encCheckFuncs { for _, encCheckFunc := range encCheckFuncs {
table, tableResult := encCheckFunc(ctx, p) table, tableResult := encCheckFunc(ctx, p)
@ -113,6 +114,46 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
return result, nil return result, nil
} }
func schemaEncryptionChangeKeyOneTimePassword(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
var count int
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimePassword)); err != nil {
return err
}
if count == 0 {
return nil
}
configs := make([]encOneTimePassword, 0, count)
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return fmt.Errorf("error selecting one time passwords: %w", err)
}
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTPEncryptedData, tableOneTimePassword))
for _, c := range configs {
if c.OTP, err = provider.decrypt(c.OTP); err != nil {
return fmt.Errorf("error decrypting one time password with id '%d': %w", c.ID, err)
}
if c.OTP, err = utils.Encrypt(c.OTP, &key); err != nil {
return fmt.Errorf("error encrypting one time password with id '%d': %w", c.ID, err)
}
if _, err = tx.ExecContext(ctx, query, c.OTP, c.ID); err != nil {
return fmt.Errorf("error updating one time password with id '%d': %w", c.ID, err)
}
}
return nil
}
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) { func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
var count int var count int
@ -134,7 +175,7 @@ func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, t
return fmt.Errorf("error selecting TOTP configurations: %w", err) return fmt.Errorf("error selecting TOTP configurations: %w", err)
} }
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations)) query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations))
for _, c := range configs { for _, c := range configs {
if c.Secret, err = provider.decrypt(c.Secret); err != nil { if c.Secret, err = provider.decrypt(c.Secret); err != nil {
@ -231,6 +272,77 @@ func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
} }
} }
func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
var count int
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil {
return err
}
if count == 0 {
return nil
}
configs := make([]encEncryption, 0, count)
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return fmt.Errorf("error selecting encyption value: %w", err)
}
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption))
for _, c := range configs {
if c.Value, err = provider.decrypt(c.Value); err != nil {
return fmt.Errorf("error decrypting encyption value with id '%d': %w", c.ID, err)
}
if c.Value, err = utils.Encrypt(c.Value, &key); err != nil {
return fmt.Errorf("error encrypting encyption value with id '%d': %w", c.ID, err)
}
if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil {
return fmt.Errorf("error updating encyption value with id '%d': %w", c.ID, err)
}
}
return nil
}
func schemaEncryptionCheckKeyOneTimePassword(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
rows *sqlx.Rows
err error
)
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time passwords: %w", err)}
}
var config encOneTimePassword
for rows.Next() {
result.Total++
if err = rows.StructScan(&config); err != nil {
_ = rows.Close()
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time password to struct: %w", err)}
}
if _, err = provider.decrypt(config.OTP); err != nil {
result.Invalid++
}
}
_ = rows.Close()
return tableOneTimePassword, result
}
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) { func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var ( var (
rows *sqlx.Rows rows *sqlx.Rows
@ -326,12 +438,77 @@ func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
} }
} }
func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
rows *sqlx.Rows
err error
)
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)}
}
var config encEncryption
for rows.Next() {
result.Total++
if err = rows.StructScan(&config); err != nil {
_ = rows.Close()
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)}
}
if _, err = provider.decrypt(config.Value); err != nil {
result.Invalid++
}
}
_ = rows.Close()
return tableEncryption, result
}
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
return utils.Encrypt(clearText, &p.key) return utils.Encrypt(clearText, &p.keys.encryption)
} }
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) { func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
return utils.Decrypt(cipherText, &p.key) return utils.Decrypt(cipherText, &p.keys.encryption)
}
func (p *SQLProvider) hmacSignature(values ...[]byte) string {
h := hmac.New(sha512.New, p.keys.signature)
for i := 0; i < len(values); i++ {
h.Write(values[i])
}
return fmt.Sprintf("%x", h.Sum(nil))
}
func (p *SQLProvider) getKeySigHMAC(ctx context.Context) (key []byte, err error) {
if key, err = p.getEncryptionValue(ctx, "hmac_signature_key"); err != nil {
if errors.Is(err, sql.ErrNoRows) {
key = make([]byte, sha512.BlockSize)
_, err = rand.Read(key)
if err != nil {
return nil, fmt.Errorf("failed to generate hmac key: %w", err)
}
if err = p.setEncryptionValue(ctx, "hmac_signature_key", key); err != nil {
return nil, err
}
return key, nil
}
return nil, err
}
return key, nil
} }
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) { func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
@ -345,6 +522,18 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu
return p.decrypt(encryptedValue) return p.decrypt(encryptedValue)
} }
func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) {
if value, err = p.encrypt(value); err != nil {
return err
}
if _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, value); err != nil {
return err
}
return nil
}
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) { func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
valueClearText, err := uuid.NewRandom() valueClearText, err := uuid.NewRandom()
if err != nil { if err != nil {

View File

@ -71,6 +71,36 @@ const (
WHERE jti = ?;` WHERE jti = ?;`
) )
const (
queryFmtSelectOTP = `
SELECT id, public_id, signature, iat, issued_ip, exp, username intent, consumed, consumed_ip, revoked, revoked_ip, password
FROM %s
WHERE signature = ? AND username = ?;`
queryFmtInsertOTP = `
INSERT INTO %s (public_id, signature, iat, issued_ip, exp, username, intent, password)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtConsumeOTP = `
UPDATE %s
SET consumed = ?, consumed_ip = ?
WHERE signature = ?;`
queryFmtRevokeOTP = `
UPDATE %s
SET revoked = CURRENT_TIMESTAMP, revoked_ip = ?
WHERE public_id = ?;`
queryFmtSelectOTPEncryptedData = `
SELECT id, password
FROM %s;`
queryFmtUpdateOTPEncryptedData = `
UPDATE %s
SET password = ?
WHERE id = ?;`
)
const ( const (
queryFmtSelectTOTPConfiguration = ` queryFmtSelectTOTPConfiguration = `
SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret
@ -87,8 +117,7 @@ const (
SELECT id, secret SELECT id, secret
FROM %s;` FROM %s;`
//nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials. queryFmtUpdateTOTPConfigurationEncryptedData = `
queryFmtUpdateTOTPConfigurationSecret = `
UPDATE %s UPDATE %s
SET secret = ? SET secret = ?
WHERE id = ?;` WHERE id = ?;`
@ -241,6 +270,14 @@ const (
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (name) ON CONFLICT (name)
DO UPDATE SET value = $2;` DO UPDATE SET value = $2;`
queryFmtSelectEncryptionEncryptedData = `
SELECT id, value
FROM %s;`
queryFmtUpdateEncryptionEncryptedData = `
UPDATE %s
SET value = ?`
) )
const ( const (

View File

@ -251,7 +251,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio
if migration.Version == 1 && migration.Up { if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1. // Add the schema encryption value if upgrading to v1.
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil { if err = p.setNewEncryptionCheckValue(ctx, conn, &p.keys.encryption); err != nil {
return err return err
} }
} }

View File

@ -32,14 +32,24 @@ type encOAuth2Session struct {
Session []byte `db:"session_data"` Session []byte `db:"session_data"`
} }
type encOneTimePassword struct {
ID int `db:"id"`
OTP []byte `db:"otp"`
}
type encTOTPConfiguration struct {
ID int `db:"id"`
Secret []byte `db:"secret"`
}
type encWebAuthnDevice struct { type encWebAuthnDevice struct {
ID int `db:"id"` ID int `db:"id"`
PublicKey []byte `db:"public_key"` PublicKey []byte `db:"public_key"`
} }
type encTOTPConfiguration struct { type encEncryption struct {
ID int `db:"id" json:"-"` ID int `db:"id"`
Secret []byte `db:"secret" json:"-"` Value []byte `db:"value"`
} }
// EncryptionValidationResult contains information about the success of a schema encryption validation. // EncryptionValidationResult contains information about the success of a schema encryption validation.