feat(web): one time password verification
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>feat-otp-verification
parent
91083f0052
commit
fe5e07d868
|
@ -39,3 +39,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
|
|||
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |
|
||||
| 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table |
|
||||
| 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes |
|
||||
| 11 | 4.38.0 | One-Time Password for Identity Verification via Email Changes |
|
||||
|
|
|
@ -24,4 +24,5 @@ type Configuration struct {
|
|||
WebAuthn WebAuthnConfiguration `koanf:"webauthn"`
|
||||
PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"`
|
||||
PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"`
|
||||
IdentityValidation IdentityValidation `koanf:"identity_validation"`
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type IdentityValidation struct {
|
||||
ResetPassword ResetPasswordIdentityValidation `koanf:"reset_password"`
|
||||
CredentialRegistration CredentialRegistrationIdentityValidation `koanf:"credential_registration"`
|
||||
}
|
||||
|
||||
type ResetPasswordIdentityValidation struct {
|
||||
EmailExpiration time.Duration `koanf:"email_expiration"`
|
||||
}
|
||||
|
||||
type CredentialRegistrationIdentityValidation struct {
|
||||
EmailExpiration time.Duration `koanf:"email_expiration"`
|
||||
ElevationExpiration time.Duration `koanf:"elevation_expiration"`
|
||||
OTPCharacters int `koanf:"otp_characters"`
|
||||
Skip2FA bool `koanf:"skip_2fa"`
|
||||
}
|
|
@ -446,7 +446,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
|
|||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
"username": "test",
|
||||
"password": "hello",
|
||||
"requestMethod": "GET",
|
||||
"requestMethod": fasthttp.MethodGet,
|
||||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
// OTPEscalationProtectedEndpointConfig represents how the Escalation middleware behaves.
|
||||
type OTPEscalationProtectedEndpointConfig struct {
|
||||
Characters int
|
||||
EmailValidityDuration time.Duration
|
||||
EscalationValidityDuration time.Duration
|
||||
Skip2FA bool
|
||||
}
|
||||
|
||||
type RequiredLevelProtectedEndpointConfig struct {
|
||||
Level authentication.Level
|
||||
}
|
||||
|
||||
type ProtectedEndpointConfig struct {
|
||||
OTPEscalation *OTPEscalationProtectedEndpointConfig
|
||||
RequiredLevel *RequiredLevelProtectedEndpointConfig
|
||||
}
|
||||
|
||||
func NewProtectedEndpoint(config *ProtectedEndpointConfig) AutheliaMiddleware {
|
||||
var handlers []ProtectedEndpointHandler
|
||||
|
||||
if config.RequiredLevel != nil {
|
||||
handlers = append(handlers, &RequiredLevelProtectedEndpointHandler{level: config.RequiredLevel.Level})
|
||||
}
|
||||
|
||||
if config.OTPEscalation != nil {
|
||||
handlers = append(handlers, &OTPEscalationProtectedEndpointHandler{config: config.OTPEscalation})
|
||||
}
|
||||
|
||||
return ProtectedEndpoint(handlers...)
|
||||
}
|
||||
|
||||
func ProtectedEndpoint(handlers ...ProtectedEndpointHandler) AutheliaMiddleware {
|
||||
n := len(handlers)
|
||||
|
||||
return func(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
session, err := ctx.GetSession()
|
||||
|
||||
if err != nil || session.IsAnonymous() {
|
||||
ctx.SetAuthenticationErrorJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var failed, failedAuthentication, failedElevation bool
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if handlers[i].Check(ctx, &session) {
|
||||
continue
|
||||
}
|
||||
|
||||
failed = true
|
||||
|
||||
if handlers[i].IsAuthentication() {
|
||||
failedAuthentication = true
|
||||
}
|
||||
|
||||
if handlers[i].IsElevation() {
|
||||
failedElevation = true
|
||||
}
|
||||
|
||||
handlers[i].Failure(ctx, &session)
|
||||
}
|
||||
|
||||
if failed {
|
||||
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, fasthttp.StatusMessage(fasthttp.StatusForbidden), failedAuthentication, failedElevation)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ProtectedEndpointHandler interface {
|
||||
Name() string
|
||||
Check(ctx *AutheliaCtx, s *session.UserSession) (success bool)
|
||||
Failure(ctx *AutheliaCtx, s *session.UserSession)
|
||||
|
||||
IsAuthentication() bool
|
||||
IsElevation() bool
|
||||
}
|
||||
|
||||
func NewRequiredLevelProtectedEndpointHandler(level authentication.Level, statusCode int) *RequiredLevelProtectedEndpointHandler {
|
||||
handler := &RequiredLevelProtectedEndpointHandler{
|
||||
level: level,
|
||||
statusCode: statusCode,
|
||||
}
|
||||
|
||||
if handler.statusCode == 0 {
|
||||
handler.statusCode = fasthttp.StatusForbidden
|
||||
}
|
||||
|
||||
if handler.level == 0 {
|
||||
handler.level = authentication.OneFactor
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
type RequiredLevelProtectedEndpointHandler struct {
|
||||
level authentication.Level
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (h *RequiredLevelProtectedEndpointHandler) Name() string {
|
||||
return fmt.Sprintf("required_level(%s)", h.level)
|
||||
}
|
||||
|
||||
func (h *RequiredLevelProtectedEndpointHandler) IsAuthentication() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *RequiredLevelProtectedEndpointHandler) IsElevation() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *RequiredLevelProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
|
||||
return s.AuthenticationLevel >= h.level
|
||||
}
|
||||
|
||||
func (h *RequiredLevelProtectedEndpointHandler) Failure(_ *AutheliaCtx, _ *session.UserSession) {
|
||||
}
|
||||
|
||||
func NewOTPEscalationProtectedEndpointHandler(config OTPEscalationProtectedEndpointConfig) *OTPEscalationProtectedEndpointHandler {
|
||||
return &OTPEscalationProtectedEndpointHandler{
|
||||
config: &config,
|
||||
}
|
||||
}
|
||||
|
||||
type OTPEscalationProtectedEndpointHandler struct {
|
||||
config *OTPEscalationProtectedEndpointConfig
|
||||
}
|
||||
|
||||
func (h *OTPEscalationProtectedEndpointHandler) Name() string {
|
||||
return "one_time_password"
|
||||
}
|
||||
|
||||
func (h *OTPEscalationProtectedEndpointHandler) IsAuthentication() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OTPEscalationProtectedEndpointHandler) IsElevation() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *OTPEscalationProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
|
||||
if h.config.Skip2FA && s.AuthenticationLevel >= authentication.TwoFactor {
|
||||
ctx.Logger.
|
||||
WithField("username", s.Username).
|
||||
Warning("User elevated session check has skipped due to 2FA")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if s.Elevations.User == nil {
|
||||
ctx.Logger.
|
||||
WithField("username", s.Username).
|
||||
Warning("User elevated session has not been created")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if s.Elevations.User.Expires.Before(ctx.Clock.Now()) {
|
||||
ctx.Logger.
|
||||
WithField("username", s.Username).
|
||||
WithField("expires", s.Elevations.User.Expires).
|
||||
Debug("User elevated session IP did not match the request")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
if !ctx.RemoteIP().Equal(s.Elevations.User.RemoteIP) {
|
||||
ctx.Logger.
|
||||
WithField("username", s.Username).
|
||||
WithField("elevation_ip", s.Elevations.User.RemoteIP).
|
||||
Warning("User elevated session IP did not match the request")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *OTPEscalationProtectedEndpointHandler) Failure(ctx *AutheliaCtx, s *session.UserSession) {
|
||||
if s.Elevations.User != nil {
|
||||
// If we make it here we should destroy the elevation data.
|
||||
s.Elevations.User = nil
|
||||
|
||||
if err := ctx.SaveSession(*s); err != nil {
|
||||
ctx.Logger.WithError(err).Error("Error session after user elevated session failure")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
|
||||
func Require1FA(next RequestHandler) RequestHandler {
|
||||
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.OneFactor, fasthttp.StatusForbidden))
|
||||
|
||||
return handler(next)
|
||||
}
|
||||
|
||||
// Require2FA requires the user to have authenticated with two-factor authentication.
|
||||
func Require2FA(next RequestHandler) RequestHandler {
|
||||
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.TwoFactor, fasthttp.StatusForbidden))
|
||||
|
||||
return handler(next)
|
||||
}
|
||||
|
||||
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
|
||||
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
session, err := ctx.GetSession()
|
||||
|
||||
if err != nil || session.AuthenticationLevel < authentication.TwoFactor {
|
||||
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,315 @@
|
|||
package middlewares_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/mocks"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
func handleSetStatus(code int) middlewares.RequestHandler {
|
||||
return func(ctx *middlewares.AutheliaCtx) {
|
||||
if err := ctx.ReplyJSON(middlewares.ErrorResponse{Status: "OK", Message: "Endpoint Response"}, 0); err != nil {
|
||||
ctx.Logger.Error(err)
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(code)
|
||||
ctx.Response.Header.Set("X-Testing-Success", "yes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedEndpointRequiredLevel(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
level authentication.Level
|
||||
have session.UserSession
|
||||
expected, status int
|
||||
expectedAuth, expectedElevate bool
|
||||
}{
|
||||
{
|
||||
name: "1FAWithAuthenticatedUser2FAShould200OK",
|
||||
level: authentication.OneFactor,
|
||||
expected: fasthttp.StatusOK,
|
||||
status: fasthttp.StatusOK,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.TwoFactor,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1FAWithAuthenticatedUser1FAShould200OK",
|
||||
level: authentication.OneFactor,
|
||||
expected: fasthttp.StatusOK,
|
||||
status: fasthttp.StatusOK,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.OneFactor,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1FAWithAuthenticatedUser2FAShould301Found",
|
||||
level: authentication.OneFactor,
|
||||
expected: fasthttp.StatusFound,
|
||||
status: fasthttp.StatusFound,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.TwoFactor,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1FAWithAuthenticatedUser1FAShould301Found",
|
||||
level: authentication.OneFactor,
|
||||
expected: fasthttp.StatusFound,
|
||||
status: fasthttp.StatusFound,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.OneFactor,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1FAWithNotAuthenticatedUserShould401Unauthenticated",
|
||||
level: authentication.OneFactor,
|
||||
expected: fasthttp.StatusUnauthorized,
|
||||
status: fasthttp.StatusOK,
|
||||
have: session.UserSession{
|
||||
AuthenticationLevel: authentication.NotAuthenticated,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
|
||||
level: authentication.OneFactor,
|
||||
expected: fasthttp.StatusUnauthorized,
|
||||
status: fasthttp.StatusFound,
|
||||
have: session.UserSession{
|
||||
AuthenticationLevel: authentication.NotAuthenticated,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
|
||||
level: authentication.TwoFactor,
|
||||
expected: fasthttp.StatusForbidden,
|
||||
expectedAuth: true,
|
||||
status: fasthttp.StatusFound,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.OneFactor,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mock := mocks.NewMockAutheliaCtx(t)
|
||||
defer mock.Close()
|
||||
|
||||
err := mock.Ctx.SaveSession(tc.have)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
var h middlewares.RequestHandler
|
||||
|
||||
switch tc.level {
|
||||
case authentication.OneFactor:
|
||||
h = middlewares.Require1FA(handleSetStatus(tc.status))
|
||||
default:
|
||||
h = middlewares.Require2FA(handleSetStatus(tc.status))
|
||||
}
|
||||
|
||||
h(mock.Ctx)
|
||||
|
||||
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||
|
||||
if tc.expected == tc.status {
|
||||
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
|
||||
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||
} else {
|
||||
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":%t,"elevation":%t}`, fasthttp.StatusMessage(tc.expected), tc.expectedAuth, tc.expectedElevate), string(mock.Ctx.Response.Body()))
|
||||
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedEndpointOTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
characters int
|
||||
emailexp, sessionexp time.Duration
|
||||
skip2fa bool
|
||||
have session.UserSession
|
||||
ip net.IP
|
||||
time time.Time
|
||||
expected, status int
|
||||
}{
|
||||
{
|
||||
name: "ReturnUnauthorizedForAnonymous",
|
||||
characters: 10,
|
||||
emailexp: time.Minute,
|
||||
sessionexp: time.Minute,
|
||||
skip2fa: false,
|
||||
expected: fasthttp.StatusUnauthorized,
|
||||
status: fasthttp.StatusFound,
|
||||
have: session.UserSession{
|
||||
AuthenticationLevel: authentication.NotAuthenticated,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Return200OKWhen2FASkipAndUserIs2FAd",
|
||||
characters: 10,
|
||||
emailexp: time.Minute,
|
||||
sessionexp: time.Minute,
|
||||
skip2fa: true,
|
||||
expected: fasthttp.StatusOK,
|
||||
status: fasthttp.StatusOK,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.TwoFactor,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HandleEscalationEmailWhen2FASkipAndUserIs1FAd",
|
||||
characters: 10,
|
||||
emailexp: time.Minute,
|
||||
sessionexp: time.Minute,
|
||||
skip2fa: true,
|
||||
expected: fasthttp.StatusForbidden,
|
||||
status: fasthttp.StatusOK,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.OneFactor,
|
||||
Elevations: session.Elevations{User: nil},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HandleEscalationEmailWhenUserIs2FAd",
|
||||
characters: 10,
|
||||
emailexp: time.Minute,
|
||||
sessionexp: time.Minute,
|
||||
skip2fa: false,
|
||||
expected: fasthttp.StatusForbidden,
|
||||
status: fasthttp.StatusOK,
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.TwoFactor,
|
||||
Elevations: session.Elevations{User: nil},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Return200OKWhenUserIsEscalated",
|
||||
characters: 10,
|
||||
emailexp: time.Minute,
|
||||
sessionexp: time.Minute,
|
||||
skip2fa: false,
|
||||
expected: fasthttp.StatusOK,
|
||||
status: fasthttp.StatusOK,
|
||||
ip: net.ParseIP("192.168.0.1"),
|
||||
time: time.Unix(1671322337, 0),
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.TwoFactor,
|
||||
Elevations: session.Elevations{
|
||||
User: &session.Elevation{
|
||||
RemoteIP: net.ParseIP("192.168.0.1"),
|
||||
Expires: time.Unix(1671322347, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Return403ForbiddenWhenUserIsEscalatedButInvalidIP",
|
||||
characters: 10,
|
||||
emailexp: time.Minute,
|
||||
sessionexp: time.Minute,
|
||||
skip2fa: false,
|
||||
expected: fasthttp.StatusForbidden,
|
||||
status: fasthttp.StatusOK,
|
||||
ip: net.ParseIP("192.168.0.2"),
|
||||
time: time.Unix(1671322337, 0),
|
||||
have: session.UserSession{
|
||||
Username: "john",
|
||||
DisplayName: "John Wick",
|
||||
Emails: []string{"john.wick@notmessingaround.com"},
|
||||
AuthenticationLevel: authentication.TwoFactor,
|
||||
Elevations: session.Elevations{
|
||||
User: &session.Elevation{
|
||||
RemoteIP: net.ParseIP("192.168.0.1"),
|
||||
Expires: time.Unix(1671322347, 0),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mock := mocks.NewMockAutheliaCtx(t)
|
||||
defer mock.Close()
|
||||
|
||||
mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, tc.ip.String())
|
||||
|
||||
err := mock.Ctx.SaveSession(tc.have)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
if !tc.time.IsZero() {
|
||||
mock.Clock.Set(tc.time)
|
||||
mock.Ctx.Clock = &mock.Clock
|
||||
}
|
||||
|
||||
h := middlewares.ProtectedEndpoint(middlewares.NewOTPEscalationProtectedEndpointHandler(middlewares.OTPEscalationProtectedEndpointConfig{
|
||||
Characters: tc.characters,
|
||||
EmailValidityDuration: tc.emailexp,
|
||||
EscalationValidityDuration: tc.sessionexp,
|
||||
Skip2FA: tc.skip2fa,
|
||||
}))(handleSetStatus(tc.status))
|
||||
|
||||
h(mock.Ctx)
|
||||
|
||||
switch {
|
||||
case tc.have.IsAnonymous():
|
||||
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":false,"elevation":false}`, fasthttp.StatusMessage(tc.expected)), string(mock.Ctx.Response.Body()))
|
||||
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||
case tc.skip2fa && tc.have.AuthenticationLevel == authentication.TwoFactor:
|
||||
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
|
||||
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||
case tc.have.Elevations.User == nil || mock.Ctx.Clock.Now().After(tc.have.Elevations.User.Expires) || !tc.ip.Equal(tc.have.Elevations.User.RemoteIP):
|
||||
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, `{"status":"KO","message":"Forbidden","authentication":false,"elevation":true}`, string(mock.Ctx.Response.Body()))
|
||||
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||
default:
|
||||
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
|
||||
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
)
|
||||
|
||||
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
|
||||
func Require1FA(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.OneFactor {
|
||||
ctx.ReplyForbidden()
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Require2FA requires the user to have authenticated with two-factor authentication.
|
||||
func Require2FA(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
|
||||
ctx.ReplyForbidden()
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
|
||||
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
|
||||
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
|
@ -110,6 +110,20 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// ConsumeOneTimePassword mocks base method.
|
||||
func (m *MockStorage) ConsumeOneTimePassword(arg0 context.Context, arg1 *model.OneTimePassword) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConsumeOneTimePassword", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ConsumeOneTimePassword indicates an expected call of ConsumeOneTimePassword.
|
||||
func (mr *MockStorageMockRecorder) ConsumeOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).ConsumeOneTimePassword), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeactivateOAuth2Session mocks base method.
|
||||
func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -299,6 +313,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// LoadOneTimePassword mocks base method.
|
||||
func (m *MockStorage) LoadOneTimePassword(arg0 context.Context, arg1, arg2, arg3 string) (*model.OneTimePassword, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadOneTimePassword", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(*model.OneTimePassword)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadOneTimePassword indicates an expected call of LoadOneTimePassword.
|
||||
func (mr *MockStorageMockRecorder) LoadOneTimePassword(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOneTimePassword", reflect.TypeOf((*MockStorage)(nil).LoadOneTimePassword), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// LoadPreferred2FAMethod mocks base method.
|
||||
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -521,6 +550,20 @@ func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, ar
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// RevokeOneTimePassword mocks base method.
|
||||
func (m *MockStorage) RevokeOneTimePassword(arg0 context.Context, arg1 uuid.UUID, arg2 model.IP) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RevokeOneTimePassword", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RevokeOneTimePassword indicates an expected call of RevokeOneTimePassword.
|
||||
func (mr *MockStorageMockRecorder) RevokeOneTimePassword(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).RevokeOneTimePassword), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Rollback mocks base method.
|
||||
func (m *MockStorage) Rollback(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -662,6 +705,21 @@ func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SaveOneTimePassword mocks base method.
|
||||
func (m *MockStorage) SaveOneTimePassword(arg0 context.Context, arg1 model.OneTimePassword) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveOneTimePassword", arg0, arg1)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SaveOneTimePassword indicates an expected call of SaveOneTimePassword.
|
||||
func (mr *MockStorageMockRecorder) SaveOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOneTimePassword", reflect.TypeOf((*MockStorage)(nil).SaveOneTimePassword), arg0, arg1)
|
||||
}
|
||||
|
||||
// SavePreferred2FAMethod mocks base method.
|
||||
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
OTPIntentElevateUserSession = "eus"
|
||||
)
|
||||
|
||||
// NewOneTimePassword returns a new OneTimePassword.
|
||||
func NewOneTimePassword(publicID uuid.UUID, username, intent string, iat, exp time.Time, ip net.IP, value []byte) (otp OneTimePassword) {
|
||||
return OneTimePassword{
|
||||
PublicID: publicID,
|
||||
IssuedAt: iat,
|
||||
ExpiresAt: exp,
|
||||
Username: username,
|
||||
Intent: intent,
|
||||
IssuedIP: NewIP(ip),
|
||||
Password: value,
|
||||
}
|
||||
}
|
||||
|
||||
// OneTimePassword represents special one time passwords stored in the database.
|
||||
type OneTimePassword struct {
|
||||
ID int `db:"id"`
|
||||
PublicID uuid.UUID `db:"public_id"`
|
||||
Signature string `db:"signature"`
|
||||
IssuedAt time.Time `db:"iat"`
|
||||
IssuedIP IP `db:"issued_ip"`
|
||||
ExpiresAt time.Time `db:"exp"`
|
||||
Username string `db:"username"`
|
||||
Intent string `db:"intent"`
|
||||
Consumed sql.NullTime `db:"consumed"`
|
||||
ConsumedIP NullIP `db:"consumed_ip"`
|
||||
Revoked sql.NullTime `db:"revoked"`
|
||||
RevokedIP NullIP `db:"revoked_ip"`
|
||||
Password []byte `db:"password"`
|
||||
}
|
|
@ -176,7 +176,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
|
|||
|
||||
middleware2FA := middlewares.NewBridgeBuilder(*config, providers).
|
||||
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
||||
WithPostMiddlewares(middlewares.Require2FAWithAPIResponse).
|
||||
WithPostMiddlewares(middlewares.Require2FA).
|
||||
Build()
|
||||
|
||||
r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
@ -138,7 +139,9 @@ type TLSServerContext struct {
|
|||
func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) {
|
||||
serverContext = new(TLSServerContext)
|
||||
|
||||
providers := middlewares.Providers{}
|
||||
providers := middlewares.Providers{
|
||||
Random: random.NewMathematical(),
|
||||
}
|
||||
|
||||
providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath})
|
||||
if err != nil {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/fasthttp/session/v2"
|
||||
|
@ -44,6 +45,8 @@ type UserSession struct {
|
|||
PasswordResetUsername *string
|
||||
|
||||
RefreshTTL time.Time
|
||||
|
||||
Elevations Elevations
|
||||
}
|
||||
|
||||
// TOTP holds the TOTP registration session data.
|
||||
|
@ -68,3 +71,13 @@ type Identity struct {
|
|||
Email string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
type Elevations struct {
|
||||
User *Elevation
|
||||
}
|
||||
|
||||
type Elevation struct {
|
||||
ID int
|
||||
RemoteIP net.IP
|
||||
Expires time.Time
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ const (
|
|||
tableAuthenticationLogs = "authentication_logs"
|
||||
tableDuoDevices = "duo_devices"
|
||||
tableIdentityVerification = "identity_verification"
|
||||
tableOneTimePassword = "one_time_password"
|
||||
tableTOTPConfigurations = "totp_configurations"
|
||||
tableUserOpaqueIdentifier = "user_opaque_identifier"
|
||||
tableUserPreferences = "user_preferences"
|
||||
|
|
|
@ -419,7 +419,7 @@ CREATE TABLE IF NOT EXISTS oauth2_consent_preconfiguration (
|
|||
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
scopes TEXT NOT NULL,
|
||||
audience TEXT NULL,
|
||||
CONSTRAINT "oauth2_consent_preconfiguration_subject_fkey"
|
||||
CONSTRAINT oauth2_consent_preconfiguration_subject_fkey
|
||||
FOREIGN KEY (subject)
|
||||
REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT
|
||||
);
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
DROP TABLE IF EXISTS one_time_password;
|
||||
DROP TABLE IF EXISTS user_elevated_session;
|
|
@ -0,0 +1,30 @@
|
|||
CREATE TABLE IF NOT EXISTS one_time_password (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
||||
public_id CHAR(36) NOT NULL,
|
||||
signature VARCHAR(128) NOT NULL,
|
||||
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
issued_ip VARCHAR(39) NOT NULL,
|
||||
exp TIMESTAMP NOT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
intent VARCHAR(100) NOT NULL,
|
||||
consumed TIMESTAMP NULL DEFAULT NULL,
|
||||
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
revoked TIMESTAMP NULL DEFAULT NULL,
|
||||
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
password BLOB NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
|
||||
|
||||
CREATE UNIQUE INDEX one_time_password_signature ON one_time_password (signature);
|
||||
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_elevated_session (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_ip VARCHAR(39) NOT NULL,
|
||||
method VARCHAR(10) NOT NULL,
|
||||
method_id INTEGER NULL,
|
||||
expires TIMESTAMP NOT NULL,
|
||||
username VARCHAR(100) NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
|
||||
|
||||
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);
|
|
@ -0,0 +1,30 @@
|
|||
CREATE TABLE IF NOT EXISTS one_time_password (
|
||||
id SERIAL CONSTRAINT one_time_password_pkey PRIMARY KEY,
|
||||
public_id CHAR(36) NOT NULL,
|
||||
signature VARCHAR(128) NOT NULL,
|
||||
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
issued_ip VARCHAR(39) NOT NULL,
|
||||
exp TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
intent VARCHAR(100) NOT NULL,
|
||||
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
revoked TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
password BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
|
||||
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_elevated_session (
|
||||
id SERIAL CONSTRAINT user_elevated_session_pkey PRIMARY KEY,
|
||||
created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_ip VARCHAR(39) NOT NULL,
|
||||
method VARCHAR(10) NOT NULL,
|
||||
method_id INTEGER NULL,
|
||||
expires TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
username VARCHAR(100) NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);
|
|
@ -0,0 +1,30 @@
|
|||
CREATE TABLE IF NOT EXISTS one_time_password (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
public_id CHAR(36) NOT NULL,
|
||||
signature VARCHAR(128) NOT NULL,
|
||||
iat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
issued_ip VARCHAR(39) NOT NULL,
|
||||
exp DATETIME NOT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
intent VARCHAR(100) NOT NULL,
|
||||
consumed DATETIME NULL DEFAULT NULL,
|
||||
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
revoked DATETIME NULL DEFAULT NULL,
|
||||
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
password BLOB NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
|
||||
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_elevated_session (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
created DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
created_ip VARCHAR(39) NOT NULL,
|
||||
method VARCHAR(10) NOT NULL,
|
||||
method_id INTEGER NULL,
|
||||
expires DATETIME NOT NULL,
|
||||
username VARCHAR(100) NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
const (
|
||||
// This is the latest schema version for the purpose of tests.
|
||||
LatestVersion = 10
|
||||
LatestVersion = 11
|
||||
)
|
||||
|
||||
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
||||
|
|
|
@ -32,6 +32,11 @@ type Provider interface {
|
|||
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
|
||||
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
||||
|
||||
SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error)
|
||||
ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error)
|
||||
RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error)
|
||||
LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error)
|
||||
|
||||
SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error)
|
||||
UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error)
|
||||
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
||||
|
|
|
@ -29,7 +29,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
driverName: driverName,
|
||||
config: config,
|
||||
errOpen: err,
|
||||
log: logging.Logger(),
|
||||
|
||||
keys: SQLProviderKeys{
|
||||
encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
|
||||
},
|
||||
|
||||
log: logging.Logger(),
|
||||
|
||||
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
|
||||
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
|
||||
|
@ -38,6 +43,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
|
||||
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
|
||||
|
||||
sqlInsertOneTimePassword: fmt.Sprintf(queryFmtInsertOTP, tableOneTimePassword),
|
||||
sqlConsumeOneTimePassword: fmt.Sprintf(queryFmtConsumeOTP, tableOneTimePassword),
|
||||
sqlRevokeOneTimePassword: fmt.Sprintf(queryFmtRevokeOTP, tableOneTimePassword),
|
||||
sqlSelectOneTimePassword: fmt.Sprintf(queryFmtSelectOTP, tableOneTimePassword),
|
||||
|
||||
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
||||
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
|
||||
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
|
||||
|
@ -139,14 +149,17 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
|
||||
// SQLProvider is a storage provider persisting data in a SQL database.
|
||||
type SQLProvider struct {
|
||||
db *sqlx.DB
|
||||
key [32]byte
|
||||
db *sqlx.DB
|
||||
key [32]byte
|
||||
|
||||
name string
|
||||
driverName string
|
||||
schema string
|
||||
config *schema.Configuration
|
||||
errOpen error
|
||||
|
||||
keys SQLProviderKeys
|
||||
|
||||
log *logrus.Logger
|
||||
|
||||
// Table: authentication_logs.
|
||||
|
@ -158,6 +171,12 @@ type SQLProvider struct {
|
|||
sqlConsumeIdentityVerification string
|
||||
sqlSelectIdentityVerification string
|
||||
|
||||
// Table: one_time_password.
|
||||
sqlInsertOneTimePassword string
|
||||
sqlConsumeOneTimePassword string
|
||||
sqlRevokeOneTimePassword string
|
||||
sqlSelectOneTimePassword string
|
||||
|
||||
// Table: totp_configurations.
|
||||
sqlUpsertTOTPConfig string
|
||||
sqlDeleteTOTPConfig string
|
||||
|
@ -274,6 +293,12 @@ type SQLProvider struct {
|
|||
sqlFmtRenameTable string
|
||||
}
|
||||
|
||||
// SQLProviderKeys are the cryptography keys used by a SQLProvider.
|
||||
type SQLProviderKeys struct {
|
||||
encryption [32]byte
|
||||
signature []byte
|
||||
}
|
||||
|
||||
// Close the underlying database connection.
|
||||
func (p *SQLProvider) Close() (err error) {
|
||||
return p.db.Close()
|
||||
|
@ -313,14 +338,19 @@ func (p *SQLProvider) StartupCheck() (err error) {
|
|||
}
|
||||
|
||||
switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
|
||||
case nil:
|
||||
break
|
||||
case ErrSchemaAlreadyUpToDate:
|
||||
p.log.Infof("Storage schema is already up to date")
|
||||
return nil
|
||||
case nil:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("error during schema migrate: %w", err)
|
||||
}
|
||||
|
||||
if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil {
|
||||
return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeginTX begins a transaction.
|
||||
|
@ -816,6 +846,62 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
|
|||
}
|
||||
}
|
||||
|
||||
// SaveOneTimePassword saves a one time password to the database after generating the signature.
|
||||
func (p *SQLProvider) SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error) {
|
||||
signature = p.hmacSignature([]byte(otp.Username), []byte(otp.Intent), otp.Password)
|
||||
|
||||
if otp.Password, err = p.encrypt(otp.Password); err != nil {
|
||||
return "", fmt.Errorf("error encrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimePassword,
|
||||
otp.PublicID, signature, otp.IssuedAt, otp.IssuedIP, otp.ExpiresAt,
|
||||
otp.Username, otp.Intent, otp.Password); err != nil {
|
||||
return "", fmt.Errorf("error inserting one time password for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// ConsumeOneTimePassword consumes a one time password using the signature.
|
||||
func (p *SQLProvider) ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimePassword, otp.Consumed, otp.ConsumedIP, otp.Signature); err != nil {
|
||||
return fmt.Errorf("error updating one time password (consume): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeOneTimePassword revokes a one time password using the public ID.
|
||||
func (p *SQLProvider) RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimePassword, ip, publicID); err != nil {
|
||||
return fmt.Errorf("error updating one time password (revoke): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadOneTimePassword loads a one time password from the database given a username, intent, and password.
|
||||
func (p *SQLProvider) LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error) {
|
||||
otp = &model.OneTimePassword{}
|
||||
|
||||
signature := p.hmacSignature([]byte(username), []byte(intent), []byte(password))
|
||||
|
||||
if err = p.db.GetContext(ctx, otp, p.sqlSelectOneTimePassword, signature, username); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("error selecting one time password: %w", err)
|
||||
}
|
||||
|
||||
if otp.Password, err = p.decrypt(otp.Password); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
|
||||
}
|
||||
|
||||
return otp, nil
|
||||
}
|
||||
|
||||
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
|
||||
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
|
||||
if config.Secret, err = p.encrypt(config.Secret); err != nil {
|
||||
|
|
|
@ -52,6 +52,11 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
|
|||
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
||||
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
|
||||
|
||||
provider.sqlInsertOneTimePassword = provider.db.Rebind(provider.sqlInsertOneTimePassword)
|
||||
provider.sqlConsumeOneTimePassword = provider.db.Rebind(provider.sqlConsumeOneTimePassword)
|
||||
provider.sqlRevokeOneTimePassword = provider.db.Rebind(provider.sqlRevokeOneTimePassword)
|
||||
provider.sqlSelectOneTimePassword = provider.db.Rebind(provider.sqlSelectOneTimePassword)
|
||||
|
||||
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
||||
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
|
||||
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)
|
||||
|
|
|
@ -3,7 +3,10 @@ package storage
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -16,10 +19,10 @@ import (
|
|||
|
||||
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
|
||||
// by this command to encrypt the values again and update them using a transaction.
|
||||
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) {
|
||||
skey := sha256.Sum256([]byte(key))
|
||||
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) {
|
||||
key := sha256.Sum256([]byte(rawKey))
|
||||
|
||||
if bytes.Equal(skey[:], p.key[:]) {
|
||||
if bytes.Equal(key[:], p.keys.encryption[:]) {
|
||||
return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
|
||||
}
|
||||
|
||||
|
@ -33,6 +36,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
|
|||
}
|
||||
|
||||
encChangeFuncs := []EncryptionChangeKeyFunc{
|
||||
schemaEncryptionChangeKeyOneTimePassword,
|
||||
schemaEncryptionChangeKeyTOTP,
|
||||
schemaEncryptionChangeKeyWebAuthn,
|
||||
}
|
||||
|
@ -47,8 +51,10 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
|
|||
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
|
||||
}
|
||||
|
||||
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption)
|
||||
|
||||
for _, encChangeFunc := range encChangeFuncs {
|
||||
if err = encChangeFunc(ctx, p, tx, skey); err != nil {
|
||||
if err = encChangeFunc(ctx, p, tx, key); err != nil {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
|
||||
}
|
||||
|
@ -57,14 +63,6 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
|
|||
}
|
||||
}
|
||||
|
||||
if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil {
|
||||
if rerr := tx.Rollback(); rerr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
|
@ -89,6 +87,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
|||
|
||||
if verbose {
|
||||
encCheckFuncs := []EncryptionCheckKeyFunc{
|
||||
schemaEncryptionCheckKeyOneTimePassword,
|
||||
schemaEncryptionCheckKeyTOTP,
|
||||
schemaEncryptionCheckKeyWebAuthn,
|
||||
}
|
||||
|
@ -103,6 +102,8 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
|||
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
|
||||
}
|
||||
|
||||
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyEncryption)
|
||||
|
||||
for _, encCheckFunc := range encCheckFuncs {
|
||||
table, tableResult := encCheckFunc(ctx, p)
|
||||
|
||||
|
@ -113,6 +114,46 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func schemaEncryptionChangeKeyOneTimePassword(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||
var count int
|
||||
|
||||
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimePassword)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
configs := make([]encOneTimePassword, 0, count)
|
||||
|
||||
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("error selecting one time passwords: %w", err)
|
||||
}
|
||||
|
||||
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTPEncryptedData, tableOneTimePassword))
|
||||
|
||||
for _, c := range configs {
|
||||
if c.OTP, err = provider.decrypt(c.OTP); err != nil {
|
||||
return fmt.Errorf("error decrypting one time password with id '%d': %w", c.ID, err)
|
||||
}
|
||||
|
||||
if c.OTP, err = utils.Encrypt(c.OTP, &key); err != nil {
|
||||
return fmt.Errorf("error encrypting one time password with id '%d': %w", c.ID, err)
|
||||
}
|
||||
|
||||
if _, err = tx.ExecContext(ctx, query, c.OTP, c.ID); err != nil {
|
||||
return fmt.Errorf("error updating one time password with id '%d': %w", c.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||
var count int
|
||||
|
||||
|
@ -134,7 +175,7 @@ func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, t
|
|||
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
||||
}
|
||||
|
||||
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations))
|
||||
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations))
|
||||
|
||||
for _, c := range configs {
|
||||
if c.Secret, err = provider.decrypt(c.Secret); err != nil {
|
||||
|
@ -231,6 +272,77 @@ func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
|
|||
}
|
||||
}
|
||||
|
||||
func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||
var count int
|
||||
|
||||
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
configs := make([]encEncryption, 0, count)
|
||||
|
||||
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("error selecting encyption value: %w", err)
|
||||
}
|
||||
|
||||
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption))
|
||||
|
||||
for _, c := range configs {
|
||||
if c.Value, err = provider.decrypt(c.Value); err != nil {
|
||||
return fmt.Errorf("error decrypting encyption value with id '%d': %w", c.ID, err)
|
||||
}
|
||||
|
||||
if c.Value, err = utils.Encrypt(c.Value, &key); err != nil {
|
||||
return fmt.Errorf("error encrypting encyption value with id '%d': %w", c.ID, err)
|
||||
}
|
||||
|
||||
if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil {
|
||||
return fmt.Errorf("error updating encyption value with id '%d': %w", c.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func schemaEncryptionCheckKeyOneTimePassword(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
||||
var (
|
||||
rows *sqlx.Rows
|
||||
err error
|
||||
)
|
||||
|
||||
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
|
||||
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time passwords: %w", err)}
|
||||
}
|
||||
|
||||
var config encOneTimePassword
|
||||
|
||||
for rows.Next() {
|
||||
result.Total++
|
||||
|
||||
if err = rows.StructScan(&config); err != nil {
|
||||
_ = rows.Close()
|
||||
|
||||
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time password to struct: %w", err)}
|
||||
}
|
||||
|
||||
if _, err = provider.decrypt(config.OTP); err != nil {
|
||||
result.Invalid++
|
||||
}
|
||||
}
|
||||
|
||||
_ = rows.Close()
|
||||
|
||||
return tableOneTimePassword, result
|
||||
}
|
||||
|
||||
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
||||
var (
|
||||
rows *sqlx.Rows
|
||||
|
@ -326,12 +438,77 @@ func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
|
|||
}
|
||||
}
|
||||
|
||||
func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
||||
var (
|
||||
rows *sqlx.Rows
|
||||
err error
|
||||
)
|
||||
|
||||
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
|
||||
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)}
|
||||
}
|
||||
|
||||
var config encEncryption
|
||||
|
||||
for rows.Next() {
|
||||
result.Total++
|
||||
|
||||
if err = rows.StructScan(&config); err != nil {
|
||||
_ = rows.Close()
|
||||
|
||||
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)}
|
||||
}
|
||||
|
||||
if _, err = provider.decrypt(config.Value); err != nil {
|
||||
result.Invalid++
|
||||
}
|
||||
}
|
||||
|
||||
_ = rows.Close()
|
||||
|
||||
return tableEncryption, result
|
||||
}
|
||||
|
||||
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||
return utils.Encrypt(clearText, &p.key)
|
||||
return utils.Encrypt(clearText, &p.keys.encryption)
|
||||
}
|
||||
|
||||
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
||||
return utils.Decrypt(cipherText, &p.key)
|
||||
return utils.Decrypt(cipherText, &p.keys.encryption)
|
||||
}
|
||||
|
||||
func (p *SQLProvider) hmacSignature(values ...[]byte) string {
|
||||
h := hmac.New(sha512.New, p.keys.signature)
|
||||
|
||||
for i := 0; i < len(values); i++ {
|
||||
h.Write(values[i])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
func (p *SQLProvider) getKeySigHMAC(ctx context.Context) (key []byte, err error) {
|
||||
if key, err = p.getEncryptionValue(ctx, "hmac_signature_key"); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
key = make([]byte, sha512.BlockSize)
|
||||
|
||||
_, err = rand.Read(key)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate hmac key: %w", err)
|
||||
}
|
||||
|
||||
if err = p.setEncryptionValue(ctx, "hmac_signature_key", key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
|
||||
|
@ -345,6 +522,18 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu
|
|||
return p.decrypt(encryptedValue)
|
||||
}
|
||||
|
||||
func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) {
|
||||
if value, err = p.encrypt(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
|
||||
valueClearText, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
|
|
|
@ -71,6 +71,36 @@ const (
|
|||
WHERE jti = ?;`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtSelectOTP = `
|
||||
SELECT id, public_id, signature, iat, issued_ip, exp, username intent, consumed, consumed_ip, revoked, revoked_ip, password
|
||||
FROM %s
|
||||
WHERE signature = ? AND username = ?;`
|
||||
|
||||
queryFmtInsertOTP = `
|
||||
INSERT INTO %s (public_id, signature, iat, issued_ip, exp, username, intent, password)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
|
||||
|
||||
queryFmtConsumeOTP = `
|
||||
UPDATE %s
|
||||
SET consumed = ?, consumed_ip = ?
|
||||
WHERE signature = ?;`
|
||||
|
||||
queryFmtRevokeOTP = `
|
||||
UPDATE %s
|
||||
SET revoked = CURRENT_TIMESTAMP, revoked_ip = ?
|
||||
WHERE public_id = ?;`
|
||||
|
||||
queryFmtSelectOTPEncryptedData = `
|
||||
SELECT id, password
|
||||
FROM %s;`
|
||||
|
||||
queryFmtUpdateOTPEncryptedData = `
|
||||
UPDATE %s
|
||||
SET password = ?
|
||||
WHERE id = ?;`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtSelectTOTPConfiguration = `
|
||||
SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret
|
||||
|
@ -87,8 +117,7 @@ const (
|
|||
SELECT id, secret
|
||||
FROM %s;`
|
||||
|
||||
//nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials.
|
||||
queryFmtUpdateTOTPConfigurationSecret = `
|
||||
queryFmtUpdateTOTPConfigurationEncryptedData = `
|
||||
UPDATE %s
|
||||
SET secret = ?
|
||||
WHERE id = ?;`
|
||||
|
@ -241,6 +270,14 @@ const (
|
|||
VALUES ($1, $2)
|
||||
ON CONFLICT (name)
|
||||
DO UPDATE SET value = $2;`
|
||||
|
||||
queryFmtSelectEncryptionEncryptedData = `
|
||||
SELECT id, value
|
||||
FROM %s;`
|
||||
|
||||
queryFmtUpdateEncryptionEncryptedData = `
|
||||
UPDATE %s
|
||||
SET value = ?`
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -251,7 +251,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio
|
|||
|
||||
if migration.Version == 1 && migration.Up {
|
||||
// Add the schema encryption value if upgrading to v1.
|
||||
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
|
||||
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.keys.encryption); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,14 +32,24 @@ type encOAuth2Session struct {
|
|||
Session []byte `db:"session_data"`
|
||||
}
|
||||
|
||||
type encOneTimePassword struct {
|
||||
ID int `db:"id"`
|
||||
OTP []byte `db:"otp"`
|
||||
}
|
||||
|
||||
type encTOTPConfiguration struct {
|
||||
ID int `db:"id"`
|
||||
Secret []byte `db:"secret"`
|
||||
}
|
||||
|
||||
type encWebAuthnDevice struct {
|
||||
ID int `db:"id"`
|
||||
PublicKey []byte `db:"public_key"`
|
||||
}
|
||||
|
||||
type encTOTPConfiguration struct {
|
||||
ID int `db:"id" json:"-"`
|
||||
Secret []byte `db:"secret" json:"-"`
|
||||
type encEncryption struct {
|
||||
ID int `db:"id"`
|
||||
Value []byte `db:"value"`
|
||||
}
|
||||
|
||||
// EncryptionValidationResult contains information about the success of a schema encryption validation.
|
||||
|
|
Loading…
Reference in New Issue