feat(web): one time password verification

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
feat-otp-verification
James Elliott 2023-04-09 16:54:58 +10:00
parent 91083f0052
commit fe5e07d868
No known key found for this signature in database
GPG Key ID: 0F1C4A096E857E49
26 changed files with 1147 additions and 75 deletions

View File

@ -39,3 +39,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |
| 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table |
| 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes |
| 11 | 4.38.0 | One-Time Password for Identity Verification via Email Changes |

View File

@ -24,4 +24,5 @@ type Configuration struct {
WebAuthn WebAuthnConfiguration `koanf:"webauthn"`
PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"`
PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"`
IdentityValidation IdentityValidation `koanf:"identity_validation"`
}

View File

@ -0,0 +1,21 @@
package schema
import (
"time"
)
type IdentityValidation struct {
ResetPassword ResetPasswordIdentityValidation `koanf:"reset_password"`
CredentialRegistration CredentialRegistrationIdentityValidation `koanf:"credential_registration"`
}
type ResetPasswordIdentityValidation struct {
EmailExpiration time.Duration `koanf:"email_expiration"`
}
type CredentialRegistrationIdentityValidation struct {
EmailExpiration time.Duration `koanf:"email_expiration"`
ElevationExpiration time.Duration `koanf:"elevation_expiration"`
OTPCharacters int `koanf:"otp_characters"`
Skip2FA bool `koanf:"skip_2fa"`
}

View File

@ -446,7 +446,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
s.mock.Ctx.Request.SetBodyString(`{
"username": "test",
"password": "hello",
"requestMethod": "GET",
"requestMethod": fasthttp.MethodGet,
"keepMeLoggedIn": false
}`)

View File

@ -0,0 +1,235 @@
package middlewares
import (
"fmt"
"time"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/session"
)
// OTPEscalationProtectedEndpointConfig represents how the Escalation middleware behaves.
type OTPEscalationProtectedEndpointConfig struct {
Characters int
EmailValidityDuration time.Duration
EscalationValidityDuration time.Duration
Skip2FA bool
}
type RequiredLevelProtectedEndpointConfig struct {
Level authentication.Level
}
type ProtectedEndpointConfig struct {
OTPEscalation *OTPEscalationProtectedEndpointConfig
RequiredLevel *RequiredLevelProtectedEndpointConfig
}
func NewProtectedEndpoint(config *ProtectedEndpointConfig) AutheliaMiddleware {
var handlers []ProtectedEndpointHandler
if config.RequiredLevel != nil {
handlers = append(handlers, &RequiredLevelProtectedEndpointHandler{level: config.RequiredLevel.Level})
}
if config.OTPEscalation != nil {
handlers = append(handlers, &OTPEscalationProtectedEndpointHandler{config: config.OTPEscalation})
}
return ProtectedEndpoint(handlers...)
}
func ProtectedEndpoint(handlers ...ProtectedEndpointHandler) AutheliaMiddleware {
n := len(handlers)
return func(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
session, err := ctx.GetSession()
if err != nil || session.IsAnonymous() {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
return
}
var failed, failedAuthentication, failedElevation bool
for i := 0; i < n; i++ {
if handlers[i].Check(ctx, &session) {
continue
}
failed = true
if handlers[i].IsAuthentication() {
failedAuthentication = true
}
if handlers[i].IsElevation() {
failedElevation = true
}
handlers[i].Failure(ctx, &session)
}
if failed {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, fasthttp.StatusMessage(fasthttp.StatusForbidden), failedAuthentication, failedElevation)
return
}
next(ctx)
}
}
}
type ProtectedEndpointHandler interface {
Name() string
Check(ctx *AutheliaCtx, s *session.UserSession) (success bool)
Failure(ctx *AutheliaCtx, s *session.UserSession)
IsAuthentication() bool
IsElevation() bool
}
func NewRequiredLevelProtectedEndpointHandler(level authentication.Level, statusCode int) *RequiredLevelProtectedEndpointHandler {
handler := &RequiredLevelProtectedEndpointHandler{
level: level,
statusCode: statusCode,
}
if handler.statusCode == 0 {
handler.statusCode = fasthttp.StatusForbidden
}
if handler.level == 0 {
handler.level = authentication.OneFactor
}
return handler
}
type RequiredLevelProtectedEndpointHandler struct {
level authentication.Level
statusCode int
}
func (h *RequiredLevelProtectedEndpointHandler) Name() string {
return fmt.Sprintf("required_level(%s)", h.level)
}
func (h *RequiredLevelProtectedEndpointHandler) IsAuthentication() bool {
return true
}
func (h *RequiredLevelProtectedEndpointHandler) IsElevation() bool {
return false
}
func (h *RequiredLevelProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
return s.AuthenticationLevel >= h.level
}
func (h *RequiredLevelProtectedEndpointHandler) Failure(_ *AutheliaCtx, _ *session.UserSession) {
}
func NewOTPEscalationProtectedEndpointHandler(config OTPEscalationProtectedEndpointConfig) *OTPEscalationProtectedEndpointHandler {
return &OTPEscalationProtectedEndpointHandler{
config: &config,
}
}
type OTPEscalationProtectedEndpointHandler struct {
config *OTPEscalationProtectedEndpointConfig
}
func (h *OTPEscalationProtectedEndpointHandler) Name() string {
return "one_time_password"
}
func (h *OTPEscalationProtectedEndpointHandler) IsAuthentication() bool {
return false
}
func (h *OTPEscalationProtectedEndpointHandler) IsElevation() bool {
return true
}
func (h *OTPEscalationProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
if h.config.Skip2FA && s.AuthenticationLevel >= authentication.TwoFactor {
ctx.Logger.
WithField("username", s.Username).
Warning("User elevated session check has skipped due to 2FA")
return true
}
if s.Elevations.User == nil {
ctx.Logger.
WithField("username", s.Username).
Warning("User elevated session has not been created")
return false
}
if s.Elevations.User.Expires.Before(ctx.Clock.Now()) {
ctx.Logger.
WithField("username", s.Username).
WithField("expires", s.Elevations.User.Expires).
Debug("User elevated session IP did not match the request")
return false
}
if !ctx.RemoteIP().Equal(s.Elevations.User.RemoteIP) {
ctx.Logger.
WithField("username", s.Username).
WithField("elevation_ip", s.Elevations.User.RemoteIP).
Warning("User elevated session IP did not match the request")
return false
}
return true
}
func (h *OTPEscalationProtectedEndpointHandler) Failure(ctx *AutheliaCtx, s *session.UserSession) {
if s.Elevations.User != nil {
// If we make it here we should destroy the elevation data.
s.Elevations.User = nil
if err := ctx.SaveSession(*s); err != nil {
ctx.Logger.WithError(err).Error("Error session after user elevated session failure")
}
}
}
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
func Require1FA(next RequestHandler) RequestHandler {
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.OneFactor, fasthttp.StatusForbidden))
return handler(next)
}
// Require2FA requires the user to have authenticated with two-factor authentication.
func Require2FA(next RequestHandler) RequestHandler {
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.TwoFactor, fasthttp.StatusForbidden))
return handler(next)
}
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
session, err := ctx.GetSession()
if err != nil || session.AuthenticationLevel < authentication.TwoFactor {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
return
}
next(ctx)
}
}

View File

@ -0,0 +1,315 @@
package middlewares_test
import (
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/session"
)
func handleSetStatus(code int) middlewares.RequestHandler {
return func(ctx *middlewares.AutheliaCtx) {
if err := ctx.ReplyJSON(middlewares.ErrorResponse{Status: "OK", Message: "Endpoint Response"}, 0); err != nil {
ctx.Logger.Error(err)
}
ctx.SetStatusCode(code)
ctx.Response.Header.Set("X-Testing-Success", "yes")
}
}
func TestProtectedEndpointRequiredLevel(t *testing.T) {
testCases := []struct {
name string
level authentication.Level
have session.UserSession
expected, status int
expectedAuth, expectedElevate bool
}{
{
name: "1FAWithAuthenticatedUser2FAShould200OK",
level: authentication.OneFactor,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
},
},
{
name: "1FAWithAuthenticatedUser1FAShould200OK",
level: authentication.OneFactor,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
},
},
{
name: "1FAWithAuthenticatedUser2FAShould301Found",
level: authentication.OneFactor,
expected: fasthttp.StatusFound,
status: fasthttp.StatusFound,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
},
},
{
name: "1FAWithAuthenticatedUser1FAShould301Found",
level: authentication.OneFactor,
expected: fasthttp.StatusFound,
status: fasthttp.StatusFound,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
},
},
{
name: "1FAWithNotAuthenticatedUserShould401Unauthenticated",
level: authentication.OneFactor,
expected: fasthttp.StatusUnauthorized,
status: fasthttp.StatusOK,
have: session.UserSession{
AuthenticationLevel: authentication.NotAuthenticated,
},
},
{
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
level: authentication.OneFactor,
expected: fasthttp.StatusUnauthorized,
status: fasthttp.StatusFound,
have: session.UserSession{
AuthenticationLevel: authentication.NotAuthenticated,
},
},
{
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
level: authentication.TwoFactor,
expected: fasthttp.StatusForbidden,
expectedAuth: true,
status: fasthttp.StatusFound,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
err := mock.Ctx.SaveSession(tc.have)
require.NoError(t, err)
var h middlewares.RequestHandler
switch tc.level {
case authentication.OneFactor:
h = middlewares.Require1FA(handleSetStatus(tc.status))
default:
h = middlewares.Require2FA(handleSetStatus(tc.status))
}
h(mock.Ctx)
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
if tc.expected == tc.status {
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
} else {
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":%t,"elevation":%t}`, fasthttp.StatusMessage(tc.expected), tc.expectedAuth, tc.expectedElevate), string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
}
})
}
}
func TestProtectedEndpointOTP(t *testing.T) {
testCases := []struct {
name string
characters int
emailexp, sessionexp time.Duration
skip2fa bool
have session.UserSession
ip net.IP
time time.Time
expected, status int
}{
{
name: "ReturnUnauthorizedForAnonymous",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusUnauthorized,
status: fasthttp.StatusFound,
have: session.UserSession{
AuthenticationLevel: authentication.NotAuthenticated,
},
},
{
name: "Return200OKWhen2FASkipAndUserIs2FAd",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: true,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
},
},
{
name: "HandleEscalationEmailWhen2FASkipAndUserIs1FAd",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: true,
expected: fasthttp.StatusForbidden,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.OneFactor,
Elevations: session.Elevations{User: nil},
},
},
{
name: "HandleEscalationEmailWhenUserIs2FAd",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusForbidden,
status: fasthttp.StatusOK,
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
Elevations: session.Elevations{User: nil},
},
},
{
name: "Return200OKWhenUserIsEscalated",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusOK,
status: fasthttp.StatusOK,
ip: net.ParseIP("192.168.0.1"),
time: time.Unix(1671322337, 0),
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
Elevations: session.Elevations{
User: &session.Elevation{
RemoteIP: net.ParseIP("192.168.0.1"),
Expires: time.Unix(1671322347, 0),
},
},
},
},
{
name: "Return403ForbiddenWhenUserIsEscalatedButInvalidIP",
characters: 10,
emailexp: time.Minute,
sessionexp: time.Minute,
skip2fa: false,
expected: fasthttp.StatusForbidden,
status: fasthttp.StatusOK,
ip: net.ParseIP("192.168.0.2"),
time: time.Unix(1671322337, 0),
have: session.UserSession{
Username: "john",
DisplayName: "John Wick",
Emails: []string{"john.wick@notmessingaround.com"},
AuthenticationLevel: authentication.TwoFactor,
Elevations: session.Elevations{
User: &session.Elevation{
RemoteIP: net.ParseIP("192.168.0.1"),
Expires: time.Unix(1671322347, 0),
},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, tc.ip.String())
err := mock.Ctx.SaveSession(tc.have)
require.NoError(t, err)
if !tc.time.IsZero() {
mock.Clock.Set(tc.time)
mock.Ctx.Clock = &mock.Clock
}
h := middlewares.ProtectedEndpoint(middlewares.NewOTPEscalationProtectedEndpointHandler(middlewares.OTPEscalationProtectedEndpointConfig{
Characters: tc.characters,
EmailValidityDuration: tc.emailexp,
EscalationValidityDuration: tc.sessionexp,
Skip2FA: tc.skip2fa,
}))(handleSetStatus(tc.status))
h(mock.Ctx)
switch {
case tc.have.IsAnonymous():
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":false,"elevation":false}`, fasthttp.StatusMessage(tc.expected)), string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
case tc.skip2fa && tc.have.AuthenticationLevel == authentication.TwoFactor:
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
case tc.have.Elevations.User == nil || mock.Ctx.Clock.Now().After(tc.have.Elevations.User.Expires) || !tc.ip.Equal(tc.have.Elevations.User.RemoteIP):
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, `{"status":"KO","message":"Forbidden","authentication":false,"elevation":true}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
default:
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
}
})
}
}

View File

@ -1,43 +0,0 @@
package middlewares
import (
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
)
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
func Require1FA(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.OneFactor {
ctx.ReplyForbidden()
return
}
next(ctx)
}
}
// Require2FA requires the user to have authenticated with two-factor authentication.
func Require2FA(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
ctx.ReplyForbidden()
return
}
next(ctx)
}
}
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
return func(ctx *AutheliaCtx) {
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
return
}
next(ctx)
}
}

View File

@ -110,6 +110,20 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
}
// ConsumeOneTimePassword mocks base method.
func (m *MockStorage) ConsumeOneTimePassword(arg0 context.Context, arg1 *model.OneTimePassword) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ConsumeOneTimePassword", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ConsumeOneTimePassword indicates an expected call of ConsumeOneTimePassword.
func (mr *MockStorageMockRecorder) ConsumeOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).ConsumeOneTimePassword), arg0, arg1)
}
// DeactivateOAuth2Session mocks base method.
func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper()
@ -299,6 +313,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
}
// LoadOneTimePassword mocks base method.
func (m *MockStorage) LoadOneTimePassword(arg0 context.Context, arg1, arg2, arg3 string) (*model.OneTimePassword, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOneTimePassword", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*model.OneTimePassword)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOneTimePassword indicates an expected call of LoadOneTimePassword.
func (mr *MockStorageMockRecorder) LoadOneTimePassword(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOneTimePassword", reflect.TypeOf((*MockStorage)(nil).LoadOneTimePassword), arg0, arg1, arg2, arg3)
}
// LoadPreferred2FAMethod mocks base method.
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
@ -521,6 +550,20 @@ func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, ar
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
}
// RevokeOneTimePassword mocks base method.
func (m *MockStorage) RevokeOneTimePassword(arg0 context.Context, arg1 uuid.UUID, arg2 model.IP) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeOneTimePassword", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeOneTimePassword indicates an expected call of RevokeOneTimePassword.
func (mr *MockStorageMockRecorder) RevokeOneTimePassword(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).RevokeOneTimePassword), arg0, arg1, arg2)
}
// Rollback mocks base method.
func (m *MockStorage) Rollback(arg0 context.Context) error {
m.ctrl.T.Helper()
@ -662,6 +705,21 @@ func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
}
// SaveOneTimePassword mocks base method.
func (m *MockStorage) SaveOneTimePassword(arg0 context.Context, arg1 model.OneTimePassword) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOneTimePassword", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SaveOneTimePassword indicates an expected call of SaveOneTimePassword.
func (mr *MockStorageMockRecorder) SaveOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOneTimePassword", reflect.TypeOf((*MockStorage)(nil).SaveOneTimePassword), arg0, arg1)
}
// SavePreferred2FAMethod mocks base method.
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()

View File

@ -0,0 +1,43 @@
package model
import (
"database/sql"
"net"
"time"
"github.com/google/uuid"
)
const (
OTPIntentElevateUserSession = "eus"
)
// NewOneTimePassword returns a new OneTimePassword.
func NewOneTimePassword(publicID uuid.UUID, username, intent string, iat, exp time.Time, ip net.IP, value []byte) (otp OneTimePassword) {
return OneTimePassword{
PublicID: publicID,
IssuedAt: iat,
ExpiresAt: exp,
Username: username,
Intent: intent,
IssuedIP: NewIP(ip),
Password: value,
}
}
// OneTimePassword represents special one time passwords stored in the database.
type OneTimePassword struct {
ID int `db:"id"`
PublicID uuid.UUID `db:"public_id"`
Signature string `db:"signature"`
IssuedAt time.Time `db:"iat"`
IssuedIP IP `db:"issued_ip"`
ExpiresAt time.Time `db:"exp"`
Username string `db:"username"`
Intent string `db:"intent"`
Consumed sql.NullTime `db:"consumed"`
ConsumedIP NullIP `db:"consumed_ip"`
Revoked sql.NullTime `db:"revoked"`
RevokedIP NullIP `db:"revoked_ip"`
Password []byte `db:"password"`
}

View File

@ -176,7 +176,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
middleware2FA := middlewares.NewBridgeBuilder(*config, providers).
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
WithPostMiddlewares(middlewares.Require2FAWithAPIResponse).
WithPostMiddlewares(middlewares.Require2FA).
Build()
r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))

View File

@ -14,6 +14,7 @@ import (
"testing"
"time"
"github.com/authelia/authelia/v4/internal/random"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
@ -138,7 +139,9 @@ type TLSServerContext struct {
func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) {
serverContext = new(TLSServerContext)
providers := middlewares.Providers{}
providers := middlewares.Providers{
Random: random.NewMathematical(),
}
providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath})
if err != nil {

View File

@ -1,6 +1,7 @@
package session
import (
"net"
"time"
"github.com/fasthttp/session/v2"
@ -44,6 +45,8 @@ type UserSession struct {
PasswordResetUsername *string
RefreshTTL time.Time
Elevations Elevations
}
// TOTP holds the TOTP registration session data.
@ -68,3 +71,13 @@ type Identity struct {
Email string
DisplayName string
}
type Elevations struct {
User *Elevation
}
type Elevation struct {
ID int
RemoteIP net.IP
Expires time.Time
}

View File

@ -8,6 +8,7 @@ const (
tableAuthenticationLogs = "authentication_logs"
tableDuoDevices = "duo_devices"
tableIdentityVerification = "identity_verification"
tableOneTimePassword = "one_time_password"
tableTOTPConfigurations = "totp_configurations"
tableUserOpaqueIdentifier = "user_opaque_identifier"
tableUserPreferences = "user_preferences"

View File

@ -419,7 +419,7 @@ CREATE TABLE IF NOT EXISTS oauth2_consent_preconfiguration (
revoked BOOLEAN NOT NULL DEFAULT FALSE,
scopes TEXT NOT NULL,
audience TEXT NULL,
CONSTRAINT "oauth2_consent_preconfiguration_subject_fkey"
CONSTRAINT oauth2_consent_preconfiguration_subject_fkey
FOREIGN KEY (subject)
REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT
);

View File

@ -0,0 +1,2 @@
DROP TABLE IF EXISTS one_time_password;
DROP TABLE IF EXISTS user_elevated_session;

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS one_time_password (
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
public_id CHAR(36) NOT NULL,
signature VARCHAR(128) NOT NULL,
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP NOT NULL,
username VARCHAR(100) NOT NULL,
intent VARCHAR(100) NOT NULL,
consumed TIMESTAMP NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
revoked TIMESTAMP NULL DEFAULT NULL,
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
password BLOB NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
CREATE UNIQUE INDEX one_time_password_signature ON one_time_password (signature);
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
CREATE TABLE IF NOT EXISTS user_elevated_session (
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_ip VARCHAR(39) NOT NULL,
method VARCHAR(10) NOT NULL,
method_id INTEGER NULL,
expires TIMESTAMP NOT NULL,
username VARCHAR(100) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS one_time_password (
id SERIAL CONSTRAINT one_time_password_pkey PRIMARY KEY,
public_id CHAR(36) NOT NULL,
signature VARCHAR(128) NOT NULL,
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp TIMESTAMP WITH TIME ZONE NOT NULL,
username VARCHAR(100) NOT NULL,
intent VARCHAR(100) NOT NULL,
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
revoked TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
password BYTEA NOT NULL
);
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
CREATE TABLE IF NOT EXISTS user_elevated_session (
id SERIAL CONSTRAINT user_elevated_session_pkey PRIMARY KEY,
created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_ip VARCHAR(39) NOT NULL,
method VARCHAR(10) NOT NULL,
method_id INTEGER NULL,
expires TIMESTAMP WITH TIME ZONE NOT NULL,
username VARCHAR(100) NOT NULL
);
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);

View File

@ -0,0 +1,30 @@
CREATE TABLE IF NOT EXISTS one_time_password (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
public_id CHAR(36) NOT NULL,
signature VARCHAR(128) NOT NULL,
iat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
issued_ip VARCHAR(39) NOT NULL,
exp DATETIME NOT NULL,
username VARCHAR(100) NOT NULL,
intent VARCHAR(100) NOT NULL,
consumed DATETIME NULL DEFAULT NULL,
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
revoked DATETIME NULL DEFAULT NULL,
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
password BLOB NOT NULL
);
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
CREATE TABLE IF NOT EXISTS user_elevated_session (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
created DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_ip VARCHAR(39) NOT NULL,
method VARCHAR(10) NOT NULL,
method_id INTEGER NULL,
expires DATETIME NOT NULL,
username VARCHAR(100) NOT NULL
);
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);

View File

@ -9,7 +9,7 @@ import (
const (
// This is the latest schema version for the purpose of tests.
LatestVersion = 10
LatestVersion = 11
)
func TestShouldObtainCorrectUpMigrations(t *testing.T) {

View File

@ -32,6 +32,11 @@ type Provider interface {
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error)
ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error)
RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error)
LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error)
SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error)
UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error)
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)

View File

@ -29,7 +29,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
driverName: driverName,
config: config,
errOpen: err,
log: logging.Logger(),
keys: SQLProviderKeys{
encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
},
log: logging.Logger(),
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
@ -38,6 +43,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
sqlInsertOneTimePassword: fmt.Sprintf(queryFmtInsertOTP, tableOneTimePassword),
sqlConsumeOneTimePassword: fmt.Sprintf(queryFmtConsumeOTP, tableOneTimePassword),
sqlRevokeOneTimePassword: fmt.Sprintf(queryFmtRevokeOTP, tableOneTimePassword),
sqlSelectOneTimePassword: fmt.Sprintf(queryFmtSelectOTP, tableOneTimePassword),
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
@ -139,14 +149,17 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
// SQLProvider is a storage provider persisting data in a SQL database.
type SQLProvider struct {
db *sqlx.DB
key [32]byte
db *sqlx.DB
key [32]byte
name string
driverName string
schema string
config *schema.Configuration
errOpen error
keys SQLProviderKeys
log *logrus.Logger
// Table: authentication_logs.
@ -158,6 +171,12 @@ type SQLProvider struct {
sqlConsumeIdentityVerification string
sqlSelectIdentityVerification string
// Table: one_time_password.
sqlInsertOneTimePassword string
sqlConsumeOneTimePassword string
sqlRevokeOneTimePassword string
sqlSelectOneTimePassword string
// Table: totp_configurations.
sqlUpsertTOTPConfig string
sqlDeleteTOTPConfig string
@ -274,6 +293,12 @@ type SQLProvider struct {
sqlFmtRenameTable string
}
// SQLProviderKeys are the cryptography keys used by a SQLProvider.
type SQLProviderKeys struct {
encryption [32]byte
signature []byte
}
// Close the underlying database connection.
func (p *SQLProvider) Close() (err error) {
return p.db.Close()
@ -313,14 +338,19 @@ func (p *SQLProvider) StartupCheck() (err error) {
}
switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
case nil:
break
case ErrSchemaAlreadyUpToDate:
p.log.Infof("Storage schema is already up to date")
return nil
case nil:
return nil
default:
return fmt.Errorf("error during schema migrate: %w", err)
}
if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil {
return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err)
}
return nil
}
// BeginTX begins a transaction.
@ -816,6 +846,62 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
}
}
// SaveOneTimePassword saves a one time password to the database after generating the signature.
func (p *SQLProvider) SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error) {
signature = p.hmacSignature([]byte(otp.Username), []byte(otp.Intent), otp.Password)
if otp.Password, err = p.encrypt(otp.Password); err != nil {
return "", fmt.Errorf("error encrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
}
if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimePassword,
otp.PublicID, signature, otp.IssuedAt, otp.IssuedIP, otp.ExpiresAt,
otp.Username, otp.Intent, otp.Password); err != nil {
return "", fmt.Errorf("error inserting one time password for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
}
return signature, nil
}
// ConsumeOneTimePassword consumes a one time password using the signature.
func (p *SQLProvider) ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimePassword, otp.Consumed, otp.ConsumedIP, otp.Signature); err != nil {
return fmt.Errorf("error updating one time password (consume): %w", err)
}
return nil
}
// RevokeOneTimePassword revokes a one time password using the public ID.
func (p *SQLProvider) RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimePassword, ip, publicID); err != nil {
return fmt.Errorf("error updating one time password (revoke): %w", err)
}
return nil
}
// LoadOneTimePassword loads a one time password from the database given a username, intent, and password.
func (p *SQLProvider) LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error) {
otp = &model.OneTimePassword{}
signature := p.hmacSignature([]byte(username), []byte(intent), []byte(password))
if err = p.db.GetContext(ctx, otp, p.sqlSelectOneTimePassword, signature, username); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("error selecting one time password: %w", err)
}
if otp.Password, err = p.decrypt(otp.Password); err != nil {
return nil, fmt.Errorf("error decrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
}
return otp, nil
}
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
if config.Secret, err = p.encrypt(config.Secret); err != nil {

View File

@ -52,6 +52,11 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
provider.sqlInsertOneTimePassword = provider.db.Rebind(provider.sqlInsertOneTimePassword)
provider.sqlConsumeOneTimePassword = provider.db.Rebind(provider.sqlConsumeOneTimePassword)
provider.sqlRevokeOneTimePassword = provider.db.Rebind(provider.sqlRevokeOneTimePassword)
provider.sqlSelectOneTimePassword = provider.db.Rebind(provider.sqlSelectOneTimePassword)
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)

View File

@ -3,7 +3,10 @@ package storage
import (
"bytes"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"database/sql"
"errors"
"fmt"
@ -16,10 +19,10 @@ import (
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
// by this command to encrypt the values again and update them using a transaction.
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) {
skey := sha256.Sum256([]byte(key))
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) {
key := sha256.Sum256([]byte(rawKey))
if bytes.Equal(skey[:], p.key[:]) {
if bytes.Equal(key[:], p.keys.encryption[:]) {
return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
}
@ -33,6 +36,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
}
encChangeFuncs := []EncryptionChangeKeyFunc{
schemaEncryptionChangeKeyOneTimePassword,
schemaEncryptionChangeKeyTOTP,
schemaEncryptionChangeKeyWebAuthn,
}
@ -47,8 +51,10 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
}
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption)
for _, encChangeFunc := range encChangeFuncs {
if err = encChangeFunc(ctx, p, tx, skey); err != nil {
if err = encChangeFunc(ctx, p, tx, key); err != nil {
if rerr := tx.Rollback(); rerr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
}
@ -57,14 +63,6 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
}
}
if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil {
if rerr := tx.Rollback(); rerr != nil {
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
}
return fmt.Errorf("rollback due to error: %w", err)
}
return tx.Commit()
}
@ -89,6 +87,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
if verbose {
encCheckFuncs := []EncryptionCheckKeyFunc{
schemaEncryptionCheckKeyOneTimePassword,
schemaEncryptionCheckKeyTOTP,
schemaEncryptionCheckKeyWebAuthn,
}
@ -103,6 +102,8 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
}
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyEncryption)
for _, encCheckFunc := range encCheckFuncs {
table, tableResult := encCheckFunc(ctx, p)
@ -113,6 +114,46 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
return result, nil
}
func schemaEncryptionChangeKeyOneTimePassword(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
var count int
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimePassword)); err != nil {
return err
}
if count == 0 {
return nil
}
configs := make([]encOneTimePassword, 0, count)
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return fmt.Errorf("error selecting one time passwords: %w", err)
}
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTPEncryptedData, tableOneTimePassword))
for _, c := range configs {
if c.OTP, err = provider.decrypt(c.OTP); err != nil {
return fmt.Errorf("error decrypting one time password with id '%d': %w", c.ID, err)
}
if c.OTP, err = utils.Encrypt(c.OTP, &key); err != nil {
return fmt.Errorf("error encrypting one time password with id '%d': %w", c.ID, err)
}
if _, err = tx.ExecContext(ctx, query, c.OTP, c.ID); err != nil {
return fmt.Errorf("error updating one time password with id '%d': %w", c.ID, err)
}
}
return nil
}
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
var count int
@ -134,7 +175,7 @@ func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, t
return fmt.Errorf("error selecting TOTP configurations: %w", err)
}
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations))
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations))
for _, c := range configs {
if c.Secret, err = provider.decrypt(c.Secret); err != nil {
@ -231,6 +272,77 @@ func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
}
}
func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
var count int
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil {
return err
}
if count == 0 {
return nil
}
configs := make([]encEncryption, 0, count)
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return fmt.Errorf("error selecting encyption value: %w", err)
}
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption))
for _, c := range configs {
if c.Value, err = provider.decrypt(c.Value); err != nil {
return fmt.Errorf("error decrypting encyption value with id '%d': %w", c.ID, err)
}
if c.Value, err = utils.Encrypt(c.Value, &key); err != nil {
return fmt.Errorf("error encrypting encyption value with id '%d': %w", c.ID, err)
}
if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil {
return fmt.Errorf("error updating encyption value with id '%d': %w", c.ID, err)
}
}
return nil
}
func schemaEncryptionCheckKeyOneTimePassword(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
rows *sqlx.Rows
err error
)
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time passwords: %w", err)}
}
var config encOneTimePassword
for rows.Next() {
result.Total++
if err = rows.StructScan(&config); err != nil {
_ = rows.Close()
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time password to struct: %w", err)}
}
if _, err = provider.decrypt(config.OTP); err != nil {
result.Invalid++
}
}
_ = rows.Close()
return tableOneTimePassword, result
}
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
rows *sqlx.Rows
@ -326,12 +438,77 @@ func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
}
}
func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
var (
rows *sqlx.Rows
err error
)
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)}
}
var config encEncryption
for rows.Next() {
result.Total++
if err = rows.StructScan(&config); err != nil {
_ = rows.Close()
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)}
}
if _, err = provider.decrypt(config.Value); err != nil {
result.Invalid++
}
}
_ = rows.Close()
return tableEncryption, result
}
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
return utils.Encrypt(clearText, &p.key)
return utils.Encrypt(clearText, &p.keys.encryption)
}
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
return utils.Decrypt(cipherText, &p.key)
return utils.Decrypt(cipherText, &p.keys.encryption)
}
func (p *SQLProvider) hmacSignature(values ...[]byte) string {
h := hmac.New(sha512.New, p.keys.signature)
for i := 0; i < len(values); i++ {
h.Write(values[i])
}
return fmt.Sprintf("%x", h.Sum(nil))
}
func (p *SQLProvider) getKeySigHMAC(ctx context.Context) (key []byte, err error) {
if key, err = p.getEncryptionValue(ctx, "hmac_signature_key"); err != nil {
if errors.Is(err, sql.ErrNoRows) {
key = make([]byte, sha512.BlockSize)
_, err = rand.Read(key)
if err != nil {
return nil, fmt.Errorf("failed to generate hmac key: %w", err)
}
if err = p.setEncryptionValue(ctx, "hmac_signature_key", key); err != nil {
return nil, err
}
return key, nil
}
return nil, err
}
return key, nil
}
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
@ -345,6 +522,18 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu
return p.decrypt(encryptedValue)
}
func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) {
if value, err = p.encrypt(value); err != nil {
return err
}
if _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, value); err != nil {
return err
}
return nil
}
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
valueClearText, err := uuid.NewRandom()
if err != nil {

View File

@ -71,6 +71,36 @@ const (
WHERE jti = ?;`
)
const (
queryFmtSelectOTP = `
SELECT id, public_id, signature, iat, issued_ip, exp, username intent, consumed, consumed_ip, revoked, revoked_ip, password
FROM %s
WHERE signature = ? AND username = ?;`
queryFmtInsertOTP = `
INSERT INTO %s (public_id, signature, iat, issued_ip, exp, username, intent, password)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtConsumeOTP = `
UPDATE %s
SET consumed = ?, consumed_ip = ?
WHERE signature = ?;`
queryFmtRevokeOTP = `
UPDATE %s
SET revoked = CURRENT_TIMESTAMP, revoked_ip = ?
WHERE public_id = ?;`
queryFmtSelectOTPEncryptedData = `
SELECT id, password
FROM %s;`
queryFmtUpdateOTPEncryptedData = `
UPDATE %s
SET password = ?
WHERE id = ?;`
)
const (
queryFmtSelectTOTPConfiguration = `
SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret
@ -87,8 +117,7 @@ const (
SELECT id, secret
FROM %s;`
//nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials.
queryFmtUpdateTOTPConfigurationSecret = `
queryFmtUpdateTOTPConfigurationEncryptedData = `
UPDATE %s
SET secret = ?
WHERE id = ?;`
@ -241,6 +270,14 @@ const (
VALUES ($1, $2)
ON CONFLICT (name)
DO UPDATE SET value = $2;`
queryFmtSelectEncryptionEncryptedData = `
SELECT id, value
FROM %s;`
queryFmtUpdateEncryptionEncryptedData = `
UPDATE %s
SET value = ?`
)
const (

View File

@ -251,7 +251,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio
if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.keys.encryption); err != nil {
return err
}
}

View File

@ -32,14 +32,24 @@ type encOAuth2Session struct {
Session []byte `db:"session_data"`
}
type encOneTimePassword struct {
ID int `db:"id"`
OTP []byte `db:"otp"`
}
type encTOTPConfiguration struct {
ID int `db:"id"`
Secret []byte `db:"secret"`
}
type encWebAuthnDevice struct {
ID int `db:"id"`
PublicKey []byte `db:"public_key"`
}
type encTOTPConfiguration struct {
ID int `db:"id" json:"-"`
Secret []byte `db:"secret" json:"-"`
type encEncryption struct {
ID int `db:"id"`
Value []byte `db:"value"`
}
// EncryptionValidationResult contains information about the success of a schema encryption validation.