feat(web): one time password verification
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>feat-otp-verification
parent
91083f0052
commit
fe5e07d868
|
@ -39,3 +39,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
|
||||||
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |
|
| 8 | 4.38.0 | OpenID Connect 1.0 Pushed Authorization Requests |
|
||||||
| 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table |
|
| 9 | 4.38.0 | Fix a PostgreSQL NOT NULL constraint issue on the `aaguid` column of the `webauthn_devices` table |
|
||||||
| 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes |
|
| 10 | 4.38.0 | WebAuthn adjustments for multi-cookie domain changes |
|
||||||
|
| 11 | 4.38.0 | One-Time Password for Identity Verification via Email Changes |
|
||||||
|
|
|
@ -24,4 +24,5 @@ type Configuration struct {
|
||||||
WebAuthn WebAuthnConfiguration `koanf:"webauthn"`
|
WebAuthn WebAuthnConfiguration `koanf:"webauthn"`
|
||||||
PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"`
|
PasswordPolicy PasswordPolicyConfiguration `koanf:"password_policy"`
|
||||||
PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"`
|
PrivacyPolicy PrivacyPolicy `koanf:"privacy_policy"`
|
||||||
|
IdentityValidation IdentityValidation `koanf:"identity_validation"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IdentityValidation struct {
|
||||||
|
ResetPassword ResetPasswordIdentityValidation `koanf:"reset_password"`
|
||||||
|
CredentialRegistration CredentialRegistrationIdentityValidation `koanf:"credential_registration"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResetPasswordIdentityValidation struct {
|
||||||
|
EmailExpiration time.Duration `koanf:"email_expiration"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CredentialRegistrationIdentityValidation struct {
|
||||||
|
EmailExpiration time.Duration `koanf:"email_expiration"`
|
||||||
|
ElevationExpiration time.Duration `koanf:"elevation_expiration"`
|
||||||
|
OTPCharacters int `koanf:"otp_characters"`
|
||||||
|
Skip2FA bool `koanf:"skip_2fa"`
|
||||||
|
}
|
|
@ -446,7 +446,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
|
||||||
s.mock.Ctx.Request.SetBodyString(`{
|
s.mock.Ctx.Request.SetBodyString(`{
|
||||||
"username": "test",
|
"username": "test",
|
||||||
"password": "hello",
|
"password": "hello",
|
||||||
"requestMethod": "GET",
|
"requestMethod": fasthttp.MethodGet,
|
||||||
"keepMeLoggedIn": false
|
"keepMeLoggedIn": false
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,235 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/authentication"
|
||||||
|
"github.com/authelia/authelia/v4/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OTPEscalationProtectedEndpointConfig represents how the Escalation middleware behaves.
|
||||||
|
type OTPEscalationProtectedEndpointConfig struct {
|
||||||
|
Characters int
|
||||||
|
EmailValidityDuration time.Duration
|
||||||
|
EscalationValidityDuration time.Duration
|
||||||
|
Skip2FA bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequiredLevelProtectedEndpointConfig struct {
|
||||||
|
Level authentication.Level
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProtectedEndpointConfig struct {
|
||||||
|
OTPEscalation *OTPEscalationProtectedEndpointConfig
|
||||||
|
RequiredLevel *RequiredLevelProtectedEndpointConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProtectedEndpoint(config *ProtectedEndpointConfig) AutheliaMiddleware {
|
||||||
|
var handlers []ProtectedEndpointHandler
|
||||||
|
|
||||||
|
if config.RequiredLevel != nil {
|
||||||
|
handlers = append(handlers, &RequiredLevelProtectedEndpointHandler{level: config.RequiredLevel.Level})
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.OTPEscalation != nil {
|
||||||
|
handlers = append(handlers, &OTPEscalationProtectedEndpointHandler{config: config.OTPEscalation})
|
||||||
|
}
|
||||||
|
|
||||||
|
return ProtectedEndpoint(handlers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProtectedEndpoint(handlers ...ProtectedEndpointHandler) AutheliaMiddleware {
|
||||||
|
n := len(handlers)
|
||||||
|
|
||||||
|
return func(next RequestHandler) RequestHandler {
|
||||||
|
return func(ctx *AutheliaCtx) {
|
||||||
|
session, err := ctx.GetSession()
|
||||||
|
|
||||||
|
if err != nil || session.IsAnonymous() {
|
||||||
|
ctx.SetAuthenticationErrorJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var failed, failedAuthentication, failedElevation bool
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if handlers[i].Check(ctx, &session) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
failed = true
|
||||||
|
|
||||||
|
if handlers[i].IsAuthentication() {
|
||||||
|
failedAuthentication = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if handlers[i].IsElevation() {
|
||||||
|
failedElevation = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handlers[i].Failure(ctx, &session)
|
||||||
|
}
|
||||||
|
|
||||||
|
if failed {
|
||||||
|
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, fasthttp.StatusMessage(fasthttp.StatusForbidden), failedAuthentication, failedElevation)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProtectedEndpointHandler interface {
|
||||||
|
Name() string
|
||||||
|
Check(ctx *AutheliaCtx, s *session.UserSession) (success bool)
|
||||||
|
Failure(ctx *AutheliaCtx, s *session.UserSession)
|
||||||
|
|
||||||
|
IsAuthentication() bool
|
||||||
|
IsElevation() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequiredLevelProtectedEndpointHandler(level authentication.Level, statusCode int) *RequiredLevelProtectedEndpointHandler {
|
||||||
|
handler := &RequiredLevelProtectedEndpointHandler{
|
||||||
|
level: level,
|
||||||
|
statusCode: statusCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler.statusCode == 0 {
|
||||||
|
handler.statusCode = fasthttp.StatusForbidden
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler.level == 0 {
|
||||||
|
handler.level = authentication.OneFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequiredLevelProtectedEndpointHandler struct {
|
||||||
|
level authentication.Level
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequiredLevelProtectedEndpointHandler) Name() string {
|
||||||
|
return fmt.Sprintf("required_level(%s)", h.level)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequiredLevelProtectedEndpointHandler) IsAuthentication() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequiredLevelProtectedEndpointHandler) IsElevation() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequiredLevelProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
|
||||||
|
return s.AuthenticationLevel >= h.level
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequiredLevelProtectedEndpointHandler) Failure(_ *AutheliaCtx, _ *session.UserSession) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOTPEscalationProtectedEndpointHandler(config OTPEscalationProtectedEndpointConfig) *OTPEscalationProtectedEndpointHandler {
|
||||||
|
return &OTPEscalationProtectedEndpointHandler{
|
||||||
|
config: &config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OTPEscalationProtectedEndpointHandler struct {
|
||||||
|
config *OTPEscalationProtectedEndpointConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OTPEscalationProtectedEndpointHandler) Name() string {
|
||||||
|
return "one_time_password"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OTPEscalationProtectedEndpointHandler) IsAuthentication() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OTPEscalationProtectedEndpointHandler) IsElevation() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OTPEscalationProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
|
||||||
|
if h.config.Skip2FA && s.AuthenticationLevel >= authentication.TwoFactor {
|
||||||
|
ctx.Logger.
|
||||||
|
WithField("username", s.Username).
|
||||||
|
Warning("User elevated session check has skipped due to 2FA")
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Elevations.User == nil {
|
||||||
|
ctx.Logger.
|
||||||
|
WithField("username", s.Username).
|
||||||
|
Warning("User elevated session has not been created")
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Elevations.User.Expires.Before(ctx.Clock.Now()) {
|
||||||
|
ctx.Logger.
|
||||||
|
WithField("username", s.Username).
|
||||||
|
WithField("expires", s.Elevations.User.Expires).
|
||||||
|
Debug("User elevated session IP did not match the request")
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.RemoteIP().Equal(s.Elevations.User.RemoteIP) {
|
||||||
|
ctx.Logger.
|
||||||
|
WithField("username", s.Username).
|
||||||
|
WithField("elevation_ip", s.Elevations.User.RemoteIP).
|
||||||
|
Warning("User elevated session IP did not match the request")
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OTPEscalationProtectedEndpointHandler) Failure(ctx *AutheliaCtx, s *session.UserSession) {
|
||||||
|
if s.Elevations.User != nil {
|
||||||
|
// If we make it here we should destroy the elevation data.
|
||||||
|
s.Elevations.User = nil
|
||||||
|
|
||||||
|
if err := ctx.SaveSession(*s); err != nil {
|
||||||
|
ctx.Logger.WithError(err).Error("Error session after user elevated session failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
|
||||||
|
func Require1FA(next RequestHandler) RequestHandler {
|
||||||
|
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.OneFactor, fasthttp.StatusForbidden))
|
||||||
|
|
||||||
|
return handler(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require2FA requires the user to have authenticated with two-factor authentication.
|
||||||
|
func Require2FA(next RequestHandler) RequestHandler {
|
||||||
|
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.TwoFactor, fasthttp.StatusForbidden))
|
||||||
|
|
||||||
|
return handler(next)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
|
||||||
|
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
|
||||||
|
return func(ctx *AutheliaCtx) {
|
||||||
|
session, err := ctx.GetSession()
|
||||||
|
|
||||||
|
if err != nil || session.AuthenticationLevel < authentication.TwoFactor {
|
||||||
|
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,315 @@
|
||||||
|
package middlewares_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/authentication"
|
||||||
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
|
"github.com/authelia/authelia/v4/internal/mocks"
|
||||||
|
"github.com/authelia/authelia/v4/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleSetStatus(code int) middlewares.RequestHandler {
|
||||||
|
return func(ctx *middlewares.AutheliaCtx) {
|
||||||
|
if err := ctx.ReplyJSON(middlewares.ErrorResponse{Status: "OK", Message: "Endpoint Response"}, 0); err != nil {
|
||||||
|
ctx.Logger.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.SetStatusCode(code)
|
||||||
|
ctx.Response.Header.Set("X-Testing-Success", "yes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedEndpointRequiredLevel(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
level authentication.Level
|
||||||
|
have session.UserSession
|
||||||
|
expected, status int
|
||||||
|
expectedAuth, expectedElevate bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1FAWithAuthenticatedUser2FAShould200OK",
|
||||||
|
level: authentication.OneFactor,
|
||||||
|
expected: fasthttp.StatusOK,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.TwoFactor,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1FAWithAuthenticatedUser1FAShould200OK",
|
||||||
|
level: authentication.OneFactor,
|
||||||
|
expected: fasthttp.StatusOK,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.OneFactor,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1FAWithAuthenticatedUser2FAShould301Found",
|
||||||
|
level: authentication.OneFactor,
|
||||||
|
expected: fasthttp.StatusFound,
|
||||||
|
status: fasthttp.StatusFound,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.TwoFactor,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1FAWithAuthenticatedUser1FAShould301Found",
|
||||||
|
level: authentication.OneFactor,
|
||||||
|
expected: fasthttp.StatusFound,
|
||||||
|
status: fasthttp.StatusFound,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.OneFactor,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1FAWithNotAuthenticatedUserShould401Unauthenticated",
|
||||||
|
level: authentication.OneFactor,
|
||||||
|
expected: fasthttp.StatusUnauthorized,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
have: session.UserSession{
|
||||||
|
AuthenticationLevel: authentication.NotAuthenticated,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
|
||||||
|
level: authentication.OneFactor,
|
||||||
|
expected: fasthttp.StatusUnauthorized,
|
||||||
|
status: fasthttp.StatusFound,
|
||||||
|
have: session.UserSession{
|
||||||
|
AuthenticationLevel: authentication.NotAuthenticated,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2FAWithNotAuthenticatedUserShould401Unauthenticated",
|
||||||
|
level: authentication.TwoFactor,
|
||||||
|
expected: fasthttp.StatusForbidden,
|
||||||
|
expectedAuth: true,
|
||||||
|
status: fasthttp.StatusFound,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.OneFactor,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
err := mock.Ctx.SaveSession(tc.have)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var h middlewares.RequestHandler
|
||||||
|
|
||||||
|
switch tc.level {
|
||||||
|
case authentication.OneFactor:
|
||||||
|
h = middlewares.Require1FA(handleSetStatus(tc.status))
|
||||||
|
default:
|
||||||
|
h = middlewares.Require2FA(handleSetStatus(tc.status))
|
||||||
|
}
|
||||||
|
|
||||||
|
h(mock.Ctx)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||||
|
|
||||||
|
if tc.expected == tc.status {
|
||||||
|
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
|
||||||
|
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":%t,"elevation":%t}`, fasthttp.StatusMessage(tc.expected), tc.expectedAuth, tc.expectedElevate), string(mock.Ctx.Response.Body()))
|
||||||
|
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtectedEndpointOTP(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
characters int
|
||||||
|
emailexp, sessionexp time.Duration
|
||||||
|
skip2fa bool
|
||||||
|
have session.UserSession
|
||||||
|
ip net.IP
|
||||||
|
time time.Time
|
||||||
|
expected, status int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ReturnUnauthorizedForAnonymous",
|
||||||
|
characters: 10,
|
||||||
|
emailexp: time.Minute,
|
||||||
|
sessionexp: time.Minute,
|
||||||
|
skip2fa: false,
|
||||||
|
expected: fasthttp.StatusUnauthorized,
|
||||||
|
status: fasthttp.StatusFound,
|
||||||
|
have: session.UserSession{
|
||||||
|
AuthenticationLevel: authentication.NotAuthenticated,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Return200OKWhen2FASkipAndUserIs2FAd",
|
||||||
|
characters: 10,
|
||||||
|
emailexp: time.Minute,
|
||||||
|
sessionexp: time.Minute,
|
||||||
|
skip2fa: true,
|
||||||
|
expected: fasthttp.StatusOK,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.TwoFactor,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HandleEscalationEmailWhen2FASkipAndUserIs1FAd",
|
||||||
|
characters: 10,
|
||||||
|
emailexp: time.Minute,
|
||||||
|
sessionexp: time.Minute,
|
||||||
|
skip2fa: true,
|
||||||
|
expected: fasthttp.StatusForbidden,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.OneFactor,
|
||||||
|
Elevations: session.Elevations{User: nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HandleEscalationEmailWhenUserIs2FAd",
|
||||||
|
characters: 10,
|
||||||
|
emailexp: time.Minute,
|
||||||
|
sessionexp: time.Minute,
|
||||||
|
skip2fa: false,
|
||||||
|
expected: fasthttp.StatusForbidden,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.TwoFactor,
|
||||||
|
Elevations: session.Elevations{User: nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Return200OKWhenUserIsEscalated",
|
||||||
|
characters: 10,
|
||||||
|
emailexp: time.Minute,
|
||||||
|
sessionexp: time.Minute,
|
||||||
|
skip2fa: false,
|
||||||
|
expected: fasthttp.StatusOK,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
ip: net.ParseIP("192.168.0.1"),
|
||||||
|
time: time.Unix(1671322337, 0),
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.TwoFactor,
|
||||||
|
Elevations: session.Elevations{
|
||||||
|
User: &session.Elevation{
|
||||||
|
RemoteIP: net.ParseIP("192.168.0.1"),
|
||||||
|
Expires: time.Unix(1671322347, 0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Return403ForbiddenWhenUserIsEscalatedButInvalidIP",
|
||||||
|
characters: 10,
|
||||||
|
emailexp: time.Minute,
|
||||||
|
sessionexp: time.Minute,
|
||||||
|
skip2fa: false,
|
||||||
|
expected: fasthttp.StatusForbidden,
|
||||||
|
status: fasthttp.StatusOK,
|
||||||
|
ip: net.ParseIP("192.168.0.2"),
|
||||||
|
time: time.Unix(1671322337, 0),
|
||||||
|
have: session.UserSession{
|
||||||
|
Username: "john",
|
||||||
|
DisplayName: "John Wick",
|
||||||
|
Emails: []string{"john.wick@notmessingaround.com"},
|
||||||
|
AuthenticationLevel: authentication.TwoFactor,
|
||||||
|
Elevations: session.Elevations{
|
||||||
|
User: &session.Elevation{
|
||||||
|
RemoteIP: net.ParseIP("192.168.0.1"),
|
||||||
|
Expires: time.Unix(1671322347, 0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, tc.ip.String())
|
||||||
|
|
||||||
|
err := mock.Ctx.SaveSession(tc.have)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !tc.time.IsZero() {
|
||||||
|
mock.Clock.Set(tc.time)
|
||||||
|
mock.Ctx.Clock = &mock.Clock
|
||||||
|
}
|
||||||
|
|
||||||
|
h := middlewares.ProtectedEndpoint(middlewares.NewOTPEscalationProtectedEndpointHandler(middlewares.OTPEscalationProtectedEndpointConfig{
|
||||||
|
Characters: tc.characters,
|
||||||
|
EmailValidityDuration: tc.emailexp,
|
||||||
|
EscalationValidityDuration: tc.sessionexp,
|
||||||
|
Skip2FA: tc.skip2fa,
|
||||||
|
}))(handleSetStatus(tc.status))
|
||||||
|
|
||||||
|
h(mock.Ctx)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case tc.have.IsAnonymous():
|
||||||
|
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, fmt.Sprintf(`{"status":"KO","message":"%s","authentication":false,"elevation":false}`, fasthttp.StatusMessage(tc.expected)), string(mock.Ctx.Response.Body()))
|
||||||
|
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||||
|
case tc.skip2fa && tc.have.AuthenticationLevel == authentication.TwoFactor:
|
||||||
|
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
|
||||||
|
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||||
|
case tc.have.Elevations.User == nil || mock.Ctx.Clock.Now().After(tc.have.Elevations.User.Expires) || !tc.ip.Equal(tc.have.Elevations.User.RemoteIP):
|
||||||
|
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, `{"status":"KO","message":"Forbidden","authentication":false,"elevation":true}`, string(mock.Ctx.Response.Body()))
|
||||||
|
assert.Equal(t, []byte(nil), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||||
|
default:
|
||||||
|
assert.Equal(t, tc.expected, mock.Ctx.Response.StatusCode())
|
||||||
|
assert.Equal(t, `{"status":"OK","message":"Endpoint Response"}`, string(mock.Ctx.Response.Body()))
|
||||||
|
assert.Equal(t, []byte("yes"), mock.Ctx.Response.Header.Peek("X-Testing-Success"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,43 +0,0 @@
|
||||||
package middlewares
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/authentication"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
|
|
||||||
func Require1FA(next RequestHandler) RequestHandler {
|
|
||||||
return func(ctx *AutheliaCtx) {
|
|
||||||
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.OneFactor {
|
|
||||||
ctx.ReplyForbidden()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Require2FA requires the user to have authenticated with two-factor authentication.
|
|
||||||
func Require2FA(next RequestHandler) RequestHandler {
|
|
||||||
return func(ctx *AutheliaCtx) {
|
|
||||||
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
|
|
||||||
ctx.ReplyForbidden()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
|
|
||||||
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
|
|
||||||
return func(ctx *AutheliaCtx) {
|
|
||||||
if session, err := ctx.GetSession(); err != nil || session.AuthenticationLevel < authentication.TwoFactor {
|
|
||||||
ctx.SetAuthenticationErrorJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -110,6 +110,20 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConsumeOneTimePassword mocks base method.
|
||||||
|
func (m *MockStorage) ConsumeOneTimePassword(arg0 context.Context, arg1 *model.OneTimePassword) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "ConsumeOneTimePassword", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsumeOneTimePassword indicates an expected call of ConsumeOneTimePassword.
|
||||||
|
func (mr *MockStorageMockRecorder) ConsumeOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).ConsumeOneTimePassword), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// DeactivateOAuth2Session mocks base method.
|
// DeactivateOAuth2Session mocks base method.
|
||||||
func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
|
func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -299,6 +313,21 @@ func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadOneTimePassword mocks base method.
|
||||||
|
func (m *MockStorage) LoadOneTimePassword(arg0 context.Context, arg1, arg2, arg3 string) (*model.OneTimePassword, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "LoadOneTimePassword", arg0, arg1, arg2, arg3)
|
||||||
|
ret0, _ := ret[0].(*model.OneTimePassword)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadOneTimePassword indicates an expected call of LoadOneTimePassword.
|
||||||
|
func (mr *MockStorageMockRecorder) LoadOneTimePassword(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOneTimePassword", reflect.TypeOf((*MockStorage)(nil).LoadOneTimePassword), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
// LoadPreferred2FAMethod mocks base method.
|
// LoadPreferred2FAMethod mocks base method.
|
||||||
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
|
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -521,6 +550,20 @@ func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, ar
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RevokeOneTimePassword mocks base method.
|
||||||
|
func (m *MockStorage) RevokeOneTimePassword(arg0 context.Context, arg1 uuid.UUID, arg2 model.IP) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RevokeOneTimePassword", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeOneTimePassword indicates an expected call of RevokeOneTimePassword.
|
||||||
|
func (mr *MockStorageMockRecorder) RevokeOneTimePassword(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOneTimePassword", reflect.TypeOf((*MockStorage)(nil).RevokeOneTimePassword), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
||||||
// Rollback mocks base method.
|
// Rollback mocks base method.
|
||||||
func (m *MockStorage) Rollback(arg0 context.Context) error {
|
func (m *MockStorage) Rollback(arg0 context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -662,6 +705,21 @@ func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveOneTimePassword mocks base method.
|
||||||
|
func (m *MockStorage) SaveOneTimePassword(arg0 context.Context, arg1 model.OneTimePassword) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SaveOneTimePassword", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveOneTimePassword indicates an expected call of SaveOneTimePassword.
|
||||||
|
func (mr *MockStorageMockRecorder) SaveOneTimePassword(arg0, arg1 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOneTimePassword", reflect.TypeOf((*MockStorage)(nil).SaveOneTimePassword), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// SavePreferred2FAMethod mocks base method.
|
// SavePreferred2FAMethod mocks base method.
|
||||||
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
|
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OTPIntentElevateUserSession = "eus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewOneTimePassword returns a new OneTimePassword.
|
||||||
|
func NewOneTimePassword(publicID uuid.UUID, username, intent string, iat, exp time.Time, ip net.IP, value []byte) (otp OneTimePassword) {
|
||||||
|
return OneTimePassword{
|
||||||
|
PublicID: publicID,
|
||||||
|
IssuedAt: iat,
|
||||||
|
ExpiresAt: exp,
|
||||||
|
Username: username,
|
||||||
|
Intent: intent,
|
||||||
|
IssuedIP: NewIP(ip),
|
||||||
|
Password: value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OneTimePassword represents special one time passwords stored in the database.
|
||||||
|
type OneTimePassword struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
PublicID uuid.UUID `db:"public_id"`
|
||||||
|
Signature string `db:"signature"`
|
||||||
|
IssuedAt time.Time `db:"iat"`
|
||||||
|
IssuedIP IP `db:"issued_ip"`
|
||||||
|
ExpiresAt time.Time `db:"exp"`
|
||||||
|
Username string `db:"username"`
|
||||||
|
Intent string `db:"intent"`
|
||||||
|
Consumed sql.NullTime `db:"consumed"`
|
||||||
|
ConsumedIP NullIP `db:"consumed_ip"`
|
||||||
|
Revoked sql.NullTime `db:"revoked"`
|
||||||
|
RevokedIP NullIP `db:"revoked_ip"`
|
||||||
|
Password []byte `db:"password"`
|
||||||
|
}
|
|
@ -176,7 +176,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
|
||||||
|
|
||||||
middleware2FA := middlewares.NewBridgeBuilder(*config, providers).
|
middleware2FA := middlewares.NewBridgeBuilder(*config, providers).
|
||||||
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
||||||
WithPostMiddlewares(middlewares.Require2FAWithAPIResponse).
|
WithPostMiddlewares(middlewares.Require2FA).
|
||||||
Build()
|
Build()
|
||||||
|
|
||||||
r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))
|
r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/random"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
@ -138,7 +139,9 @@ type TLSServerContext struct {
|
||||||
func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) {
|
func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLSServerContext, err error) {
|
||||||
serverContext = new(TLSServerContext)
|
serverContext = new(TLSServerContext)
|
||||||
|
|
||||||
providers := middlewares.Providers{}
|
providers := middlewares.Providers{
|
||||||
|
Random: random.NewMathematical(),
|
||||||
|
}
|
||||||
|
|
||||||
providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath})
|
providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: configuration.Notifier.TemplatePath})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fasthttp/session/v2"
|
"github.com/fasthttp/session/v2"
|
||||||
|
@ -44,6 +45,8 @@ type UserSession struct {
|
||||||
PasswordResetUsername *string
|
PasswordResetUsername *string
|
||||||
|
|
||||||
RefreshTTL time.Time
|
RefreshTTL time.Time
|
||||||
|
|
||||||
|
Elevations Elevations
|
||||||
}
|
}
|
||||||
|
|
||||||
// TOTP holds the TOTP registration session data.
|
// TOTP holds the TOTP registration session data.
|
||||||
|
@ -68,3 +71,13 @@ type Identity struct {
|
||||||
Email string
|
Email string
|
||||||
DisplayName string
|
DisplayName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Elevations struct {
|
||||||
|
User *Elevation
|
||||||
|
}
|
||||||
|
|
||||||
|
type Elevation struct {
|
||||||
|
ID int
|
||||||
|
RemoteIP net.IP
|
||||||
|
Expires time.Time
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ const (
|
||||||
tableAuthenticationLogs = "authentication_logs"
|
tableAuthenticationLogs = "authentication_logs"
|
||||||
tableDuoDevices = "duo_devices"
|
tableDuoDevices = "duo_devices"
|
||||||
tableIdentityVerification = "identity_verification"
|
tableIdentityVerification = "identity_verification"
|
||||||
|
tableOneTimePassword = "one_time_password"
|
||||||
tableTOTPConfigurations = "totp_configurations"
|
tableTOTPConfigurations = "totp_configurations"
|
||||||
tableUserOpaqueIdentifier = "user_opaque_identifier"
|
tableUserOpaqueIdentifier = "user_opaque_identifier"
|
||||||
tableUserPreferences = "user_preferences"
|
tableUserPreferences = "user_preferences"
|
||||||
|
|
|
@ -419,7 +419,7 @@ CREATE TABLE IF NOT EXISTS oauth2_consent_preconfiguration (
|
||||||
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
revoked BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
scopes TEXT NOT NULL,
|
scopes TEXT NOT NULL,
|
||||||
audience TEXT NULL,
|
audience TEXT NULL,
|
||||||
CONSTRAINT "oauth2_consent_preconfiguration_subject_fkey"
|
CONSTRAINT oauth2_consent_preconfiguration_subject_fkey
|
||||||
FOREIGN KEY (subject)
|
FOREIGN KEY (subject)
|
||||||
REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT
|
REFERENCES user_opaque_identifier (identifier) ON UPDATE CASCADE ON DELETE RESTRICT
|
||||||
);
|
);
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
DROP TABLE IF EXISTS one_time_password;
|
||||||
|
DROP TABLE IF EXISTS user_elevated_session;
|
|
@ -0,0 +1,30 @@
|
||||||
|
CREATE TABLE IF NOT EXISTS one_time_password (
|
||||||
|
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
||||||
|
public_id CHAR(36) NOT NULL,
|
||||||
|
signature VARCHAR(128) NOT NULL,
|
||||||
|
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
issued_ip VARCHAR(39) NOT NULL,
|
||||||
|
exp TIMESTAMP NOT NULL,
|
||||||
|
username VARCHAR(100) NOT NULL,
|
||||||
|
intent VARCHAR(100) NOT NULL,
|
||||||
|
consumed TIMESTAMP NULL DEFAULT NULL,
|
||||||
|
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
|
revoked TIMESTAMP NULL DEFAULT NULL,
|
||||||
|
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
|
password BLOB NOT NULL
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX one_time_password_signature ON one_time_password (signature);
|
||||||
|
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_elevated_session (
|
||||||
|
id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
||||||
|
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
created_ip VARCHAR(39) NOT NULL,
|
||||||
|
method VARCHAR(10) NOT NULL,
|
||||||
|
method_id INTEGER NULL,
|
||||||
|
expires TIMESTAMP NOT NULL,
|
||||||
|
username VARCHAR(100) NOT NULL
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_520_ci;
|
||||||
|
|
||||||
|
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);
|
|
@ -0,0 +1,30 @@
|
||||||
|
CREATE TABLE IF NOT EXISTS one_time_password (
|
||||||
|
id SERIAL CONSTRAINT one_time_password_pkey PRIMARY KEY,
|
||||||
|
public_id CHAR(36) NOT NULL,
|
||||||
|
signature VARCHAR(128) NOT NULL,
|
||||||
|
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
issued_ip VARCHAR(39) NOT NULL,
|
||||||
|
exp TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||||
|
username VARCHAR(100) NOT NULL,
|
||||||
|
intent VARCHAR(100) NOT NULL,
|
||||||
|
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||||
|
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
|
revoked TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||||
|
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
|
password BYTEA NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
|
||||||
|
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_elevated_session (
|
||||||
|
id SERIAL CONSTRAINT user_elevated_session_pkey PRIMARY KEY,
|
||||||
|
created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
created_ip VARCHAR(39) NOT NULL,
|
||||||
|
method VARCHAR(10) NOT NULL,
|
||||||
|
method_id INTEGER NULL,
|
||||||
|
expires TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||||
|
username VARCHAR(100) NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);
|
|
@ -0,0 +1,30 @@
|
||||||
|
CREATE TABLE IF NOT EXISTS one_time_password (
|
||||||
|
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
public_id CHAR(36) NOT NULL,
|
||||||
|
signature VARCHAR(128) NOT NULL,
|
||||||
|
iat DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
issued_ip VARCHAR(39) NOT NULL,
|
||||||
|
exp DATETIME NOT NULL,
|
||||||
|
username VARCHAR(100) NOT NULL,
|
||||||
|
intent VARCHAR(100) NOT NULL,
|
||||||
|
consumed DATETIME NULL DEFAULT NULL,
|
||||||
|
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
|
revoked DATETIME NULL DEFAULT NULL,
|
||||||
|
revoked_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||||
|
password BLOB NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX one_time_password_lookup_key ON one_time_password (signature, username);
|
||||||
|
CREATE INDEX one_time_password_lookup ON one_time_password (signature, username);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_elevated_session (
|
||||||
|
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||||
|
created DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
created_ip VARCHAR(39) NOT NULL,
|
||||||
|
method VARCHAR(10) NOT NULL,
|
||||||
|
method_id INTEGER NULL,
|
||||||
|
expires DATETIME NOT NULL,
|
||||||
|
username VARCHAR(100) NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX user_elevated_session_username ON user_elevated_session (username);
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// This is the latest schema version for the purpose of tests.
|
// This is the latest schema version for the purpose of tests.
|
||||||
LatestVersion = 10
|
LatestVersion = 11
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
||||||
|
|
|
@ -32,6 +32,11 @@ type Provider interface {
|
||||||
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
|
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
|
||||||
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
||||||
|
|
||||||
|
SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error)
|
||||||
|
ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error)
|
||||||
|
RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error)
|
||||||
|
LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error)
|
||||||
|
|
||||||
SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error)
|
SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error)
|
||||||
UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error)
|
UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt sql.NullTime) (err error)
|
||||||
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
||||||
|
|
|
@ -29,6 +29,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
||||||
driverName: driverName,
|
driverName: driverName,
|
||||||
config: config,
|
config: config,
|
||||||
errOpen: err,
|
errOpen: err,
|
||||||
|
|
||||||
|
keys: SQLProviderKeys{
|
||||||
|
encryption: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
|
||||||
|
},
|
||||||
|
|
||||||
log: logging.Logger(),
|
log: logging.Logger(),
|
||||||
|
|
||||||
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
|
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
|
||||||
|
@ -38,6 +43,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
||||||
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
|
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
|
||||||
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
|
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
|
||||||
|
|
||||||
|
sqlInsertOneTimePassword: fmt.Sprintf(queryFmtInsertOTP, tableOneTimePassword),
|
||||||
|
sqlConsumeOneTimePassword: fmt.Sprintf(queryFmtConsumeOTP, tableOneTimePassword),
|
||||||
|
sqlRevokeOneTimePassword: fmt.Sprintf(queryFmtRevokeOTP, tableOneTimePassword),
|
||||||
|
sqlSelectOneTimePassword: fmt.Sprintf(queryFmtSelectOTP, tableOneTimePassword),
|
||||||
|
|
||||||
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
||||||
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
|
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
|
||||||
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
|
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
|
||||||
|
@ -141,12 +151,15 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
||||||
type SQLProvider struct {
|
type SQLProvider struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
key [32]byte
|
key [32]byte
|
||||||
|
|
||||||
name string
|
name string
|
||||||
driverName string
|
driverName string
|
||||||
schema string
|
schema string
|
||||||
config *schema.Configuration
|
config *schema.Configuration
|
||||||
errOpen error
|
errOpen error
|
||||||
|
|
||||||
|
keys SQLProviderKeys
|
||||||
|
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
|
|
||||||
// Table: authentication_logs.
|
// Table: authentication_logs.
|
||||||
|
@ -158,6 +171,12 @@ type SQLProvider struct {
|
||||||
sqlConsumeIdentityVerification string
|
sqlConsumeIdentityVerification string
|
||||||
sqlSelectIdentityVerification string
|
sqlSelectIdentityVerification string
|
||||||
|
|
||||||
|
// Table: one_time_password.
|
||||||
|
sqlInsertOneTimePassword string
|
||||||
|
sqlConsumeOneTimePassword string
|
||||||
|
sqlRevokeOneTimePassword string
|
||||||
|
sqlSelectOneTimePassword string
|
||||||
|
|
||||||
// Table: totp_configurations.
|
// Table: totp_configurations.
|
||||||
sqlUpsertTOTPConfig string
|
sqlUpsertTOTPConfig string
|
||||||
sqlDeleteTOTPConfig string
|
sqlDeleteTOTPConfig string
|
||||||
|
@ -274,6 +293,12 @@ type SQLProvider struct {
|
||||||
sqlFmtRenameTable string
|
sqlFmtRenameTable string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SQLProviderKeys are the cryptography keys used by a SQLProvider.
|
||||||
|
type SQLProviderKeys struct {
|
||||||
|
encryption [32]byte
|
||||||
|
signature []byte
|
||||||
|
}
|
||||||
|
|
||||||
// Close the underlying database connection.
|
// Close the underlying database connection.
|
||||||
func (p *SQLProvider) Close() (err error) {
|
func (p *SQLProvider) Close() (err error) {
|
||||||
return p.db.Close()
|
return p.db.Close()
|
||||||
|
@ -313,14 +338,19 @@ func (p *SQLProvider) StartupCheck() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
|
switch err = p.SchemaMigrate(ctx, true, SchemaLatest); err {
|
||||||
|
case nil:
|
||||||
|
break
|
||||||
case ErrSchemaAlreadyUpToDate:
|
case ErrSchemaAlreadyUpToDate:
|
||||||
p.log.Infof("Storage schema is already up to date")
|
p.log.Infof("Storage schema is already up to date")
|
||||||
return nil
|
|
||||||
case nil:
|
|
||||||
return nil
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("error during schema migrate: %w", err)
|
return fmt.Errorf("error during schema migrate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.keys.signature, err = p.getKeySigHMAC(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize the hmac signature key during startup: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeginTX begins a transaction.
|
// BeginTX begins a transaction.
|
||||||
|
@ -816,6 +846,62 @@ func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveOneTimePassword saves a one time password to the database after generating the signature.
|
||||||
|
func (p *SQLProvider) SaveOneTimePassword(ctx context.Context, otp model.OneTimePassword) (signature string, err error) {
|
||||||
|
signature = p.hmacSignature([]byte(otp.Username), []byte(otp.Intent), otp.Password)
|
||||||
|
|
||||||
|
if otp.Password, err = p.encrypt(otp.Password); err != nil {
|
||||||
|
return "", fmt.Errorf("error encrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = p.db.ExecContext(ctx, p.sqlInsertOneTimePassword,
|
||||||
|
otp.PublicID, signature, otp.IssuedAt, otp.IssuedIP, otp.ExpiresAt,
|
||||||
|
otp.Username, otp.Intent, otp.Password); err != nil {
|
||||||
|
return "", fmt.Errorf("error inserting one time password for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConsumeOneTimePassword consumes a one time password using the signature.
|
||||||
|
func (p *SQLProvider) ConsumeOneTimePassword(ctx context.Context, otp *model.OneTimePassword) (err error) {
|
||||||
|
if _, err = p.db.ExecContext(ctx, p.sqlConsumeOneTimePassword, otp.Consumed, otp.ConsumedIP, otp.Signature); err != nil {
|
||||||
|
return fmt.Errorf("error updating one time password (consume): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeOneTimePassword revokes a one time password using the public ID.
|
||||||
|
func (p *SQLProvider) RevokeOneTimePassword(ctx context.Context, publicID uuid.UUID, ip model.IP) (err error) {
|
||||||
|
if _, err = p.db.ExecContext(ctx, p.sqlRevokeOneTimePassword, ip, publicID); err != nil {
|
||||||
|
return fmt.Errorf("error updating one time password (revoke): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadOneTimePassword loads a one time password from the database given a username, intent, and password.
|
||||||
|
func (p *SQLProvider) LoadOneTimePassword(ctx context.Context, username, intent, password string) (otp *model.OneTimePassword, err error) {
|
||||||
|
otp = &model.OneTimePassword{}
|
||||||
|
|
||||||
|
signature := p.hmacSignature([]byte(username), []byte(intent), []byte(password))
|
||||||
|
|
||||||
|
if err = p.db.GetContext(ctx, otp, p.sqlSelectOneTimePassword, signature, username); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("error selecting one time password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if otp.Password, err = p.decrypt(otp.Password); err != nil {
|
||||||
|
return nil, fmt.Errorf("error decrypting the one time password value for user '%s' with signature '%s': %w", otp.Username, otp.Signature, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return otp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
|
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
|
||||||
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
|
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
|
||||||
if config.Secret, err = p.encrypt(config.Secret); err != nil {
|
if config.Secret, err = p.encrypt(config.Secret); err != nil {
|
||||||
|
|
|
@ -52,6 +52,11 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo
|
||||||
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
||||||
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
|
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
|
||||||
|
|
||||||
|
provider.sqlInsertOneTimePassword = provider.db.Rebind(provider.sqlInsertOneTimePassword)
|
||||||
|
provider.sqlConsumeOneTimePassword = provider.db.Rebind(provider.sqlConsumeOneTimePassword)
|
||||||
|
provider.sqlRevokeOneTimePassword = provider.db.Rebind(provider.sqlRevokeOneTimePassword)
|
||||||
|
provider.sqlSelectOneTimePassword = provider.db.Rebind(provider.sqlSelectOneTimePassword)
|
||||||
|
|
||||||
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
||||||
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
|
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
|
||||||
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)
|
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)
|
||||||
|
|
|
@ -3,7 +3,10 @@ package storage
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -16,10 +19,10 @@ import (
|
||||||
|
|
||||||
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
|
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
|
||||||
// by this command to encrypt the values again and update them using a transaction.
|
// by this command to encrypt the values again and update them using a transaction.
|
||||||
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string) (err error) {
|
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, rawKey string) (err error) {
|
||||||
skey := sha256.Sum256([]byte(key))
|
key := sha256.Sum256([]byte(rawKey))
|
||||||
|
|
||||||
if bytes.Equal(skey[:], p.key[:]) {
|
if bytes.Equal(key[:], p.keys.encryption[:]) {
|
||||||
return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
|
return fmt.Errorf("error changing the storage encryption key: the old key and the new key are the same")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,6 +36,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
|
||||||
}
|
}
|
||||||
|
|
||||||
encChangeFuncs := []EncryptionChangeKeyFunc{
|
encChangeFuncs := []EncryptionChangeKeyFunc{
|
||||||
|
schemaEncryptionChangeKeyOneTimePassword,
|
||||||
schemaEncryptionChangeKeyTOTP,
|
schemaEncryptionChangeKeyTOTP,
|
||||||
schemaEncryptionChangeKeyWebAuthn,
|
schemaEncryptionChangeKeyWebAuthn,
|
||||||
}
|
}
|
||||||
|
@ -47,8 +51,10 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
|
||||||
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
|
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encChangeFuncs = append(encChangeFuncs, schemaEncryptionChangeKeyEncryption)
|
||||||
|
|
||||||
for _, encChangeFunc := range encChangeFuncs {
|
for _, encChangeFunc := range encChangeFuncs {
|
||||||
if err = encChangeFunc(ctx, p, tx, skey); err != nil {
|
if err = encChangeFunc(ctx, p, tx, key); err != nil {
|
||||||
if rerr := tx.Rollback(); rerr != nil {
|
if rerr := tx.Rollback(); rerr != nil {
|
||||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
|
||||||
}
|
}
|
||||||
|
@ -57,14 +63,6 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, key string)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.setNewEncryptionCheckValue(ctx, tx, &skey); err != nil {
|
|
||||||
if rerr := tx.Rollback(); rerr != nil {
|
|
||||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rerr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("rollback due to error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,6 +87,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
||||||
|
|
||||||
if verbose {
|
if verbose {
|
||||||
encCheckFuncs := []EncryptionCheckKeyFunc{
|
encCheckFuncs := []EncryptionCheckKeyFunc{
|
||||||
|
schemaEncryptionCheckKeyOneTimePassword,
|
||||||
schemaEncryptionCheckKeyTOTP,
|
schemaEncryptionCheckKeyTOTP,
|
||||||
schemaEncryptionCheckKeyWebAuthn,
|
schemaEncryptionCheckKeyWebAuthn,
|
||||||
}
|
}
|
||||||
|
@ -103,6 +102,8 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
||||||
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
|
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encCheckFuncs = append(encCheckFuncs, schemaEncryptionCheckKeyEncryption)
|
||||||
|
|
||||||
for _, encCheckFunc := range encCheckFuncs {
|
for _, encCheckFunc := range encCheckFuncs {
|
||||||
table, tableResult := encCheckFunc(ctx, p)
|
table, tableResult := encCheckFunc(ctx, p)
|
||||||
|
|
||||||
|
@ -113,6 +114,46 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func schemaEncryptionChangeKeyOneTimePassword(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||||
|
var count int
|
||||||
|
|
||||||
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableOneTimePassword)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := make([]encOneTimePassword, 0, count)
|
||||||
|
|
||||||
|
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("error selecting one time passwords: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateOTPEncryptedData, tableOneTimePassword))
|
||||||
|
|
||||||
|
for _, c := range configs {
|
||||||
|
if c.OTP, err = provider.decrypt(c.OTP); err != nil {
|
||||||
|
return fmt.Errorf("error decrypting one time password with id '%d': %w", c.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.OTP, err = utils.Encrypt(c.OTP, &key); err != nil {
|
||||||
|
return fmt.Errorf("error encrypting one time password with id '%d': %w", c.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = tx.ExecContext(ctx, query, c.OTP, c.ID); err != nil {
|
||||||
|
return fmt.Errorf("error updating one time password with id '%d': %w", c.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||||
var count int
|
var count int
|
||||||
|
|
||||||
|
@ -134,7 +175,7 @@ func schemaEncryptionChangeKeyTOTP(ctx context.Context, provider *SQLProvider, t
|
||||||
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations))
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateTOTPConfigurationEncryptedData, tableTOTPConfigurations))
|
||||||
|
|
||||||
for _, c := range configs {
|
for _, c := range configs {
|
||||||
if c.Secret, err = provider.decrypt(c.Secret); err != nil {
|
if c.Secret, err = provider.decrypt(c.Secret); err != nil {
|
||||||
|
@ -231,6 +272,77 @@ func schemaEncryptionChangeKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func schemaEncryptionChangeKeyEncryption(ctx context.Context, provider *SQLProvider, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||||
|
var count int
|
||||||
|
|
||||||
|
if err = tx.GetContext(ctx, &count, fmt.Sprintf(queryFmtSelectRowCount, tableEncryption)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := make([]encEncryption, 0, count)
|
||||||
|
|
||||||
|
if err = tx.SelectContext(ctx, &configs, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("error selecting encyption value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := provider.db.Rebind(fmt.Sprintf(queryFmtUpdateEncryptionEncryptedData, tableEncryption))
|
||||||
|
|
||||||
|
for _, c := range configs {
|
||||||
|
if c.Value, err = provider.decrypt(c.Value); err != nil {
|
||||||
|
return fmt.Errorf("error decrypting encyption value with id '%d': %w", c.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Value, err = utils.Encrypt(c.Value, &key); err != nil {
|
||||||
|
return fmt.Errorf("error encrypting encyption value with id '%d': %w", c.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = tx.ExecContext(ctx, query, c.Value, c.ID); err != nil {
|
||||||
|
return fmt.Errorf("error updating encyption value with id '%d': %w", c.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaEncryptionCheckKeyOneTimePassword(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
||||||
|
var (
|
||||||
|
rows *sqlx.Rows
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectOTPEncryptedData, tableOneTimePassword)); err != nil {
|
||||||
|
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting one time passwords: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
var config encOneTimePassword
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
result.Total++
|
||||||
|
|
||||||
|
if err = rows.StructScan(&config); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
return tableOneTimePassword, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning one time password to struct: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = provider.decrypt(config.OTP); err != nil {
|
||||||
|
result.Invalid++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
return tableOneTimePassword, result
|
||||||
|
}
|
||||||
|
|
||||||
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
func schemaEncryptionCheckKeyTOTP(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
||||||
var (
|
var (
|
||||||
rows *sqlx.Rows
|
rows *sqlx.Rows
|
||||||
|
@ -326,12 +438,77 @@ func schemaEncryptionCheckKeyOpenIDConnect(typeOAuth2Session OAuth2SessionType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func schemaEncryptionCheckKeyEncryption(ctx context.Context, provider *SQLProvider) (table string, result EncryptionValidationTableResult) {
|
||||||
|
var (
|
||||||
|
rows *sqlx.Rows
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if rows, err = provider.db.QueryxContext(ctx, fmt.Sprintf(queryFmtSelectEncryptionEncryptedData, tableEncryption)); err != nil {
|
||||||
|
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error selecting encryption values: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
var config encEncryption
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
result.Total++
|
||||||
|
|
||||||
|
if err = rows.StructScan(&config); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
return tableEncryption, EncryptionValidationTableResult{Error: fmt.Errorf("error scanning encryption value to struct: %w", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = provider.decrypt(config.Value); err != nil {
|
||||||
|
result.Invalid++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = rows.Close()
|
||||||
|
|
||||||
|
return tableEncryption, result
|
||||||
|
}
|
||||||
|
|
||||||
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||||
return utils.Encrypt(clearText, &p.key)
|
return utils.Encrypt(clearText, &p.keys.encryption)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
||||||
return utils.Decrypt(cipherText, &p.key)
|
return utils.Decrypt(cipherText, &p.keys.encryption)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) hmacSignature(values ...[]byte) string {
|
||||||
|
h := hmac.New(sha512.New, p.keys.signature)
|
||||||
|
|
||||||
|
for i := 0; i < len(values); i++ {
|
||||||
|
h.Write(values[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) getKeySigHMAC(ctx context.Context) (key []byte, err error) {
|
||||||
|
if key, err = p.getEncryptionValue(ctx, "hmac_signature_key"); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
key = make([]byte, sha512.BlockSize)
|
||||||
|
|
||||||
|
_, err = rand.Read(key)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate hmac key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = p.setEncryptionValue(ctx, "hmac_signature_key", key); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
|
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
|
||||||
|
@ -345,6 +522,18 @@ func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (valu
|
||||||
return p.decrypt(encryptedValue)
|
return p.decrypt(encryptedValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *SQLProvider) setEncryptionValue(ctx context.Context, name string, value []byte) (err error) {
|
||||||
|
if value, err = p.encrypt(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, name, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
|
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, conn SQLXConnection, key *[32]byte) (err error) {
|
||||||
valueClearText, err := uuid.NewRandom()
|
valueClearText, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -71,6 +71,36 @@ const (
|
||||||
WHERE jti = ?;`
|
WHERE jti = ?;`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
queryFmtSelectOTP = `
|
||||||
|
SELECT id, public_id, signature, iat, issued_ip, exp, username intent, consumed, consumed_ip, revoked, revoked_ip, password
|
||||||
|
FROM %s
|
||||||
|
WHERE signature = ? AND username = ?;`
|
||||||
|
|
||||||
|
queryFmtInsertOTP = `
|
||||||
|
INSERT INTO %s (public_id, signature, iat, issued_ip, exp, username, intent, password)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
|
||||||
|
|
||||||
|
queryFmtConsumeOTP = `
|
||||||
|
UPDATE %s
|
||||||
|
SET consumed = ?, consumed_ip = ?
|
||||||
|
WHERE signature = ?;`
|
||||||
|
|
||||||
|
queryFmtRevokeOTP = `
|
||||||
|
UPDATE %s
|
||||||
|
SET revoked = CURRENT_TIMESTAMP, revoked_ip = ?
|
||||||
|
WHERE public_id = ?;`
|
||||||
|
|
||||||
|
queryFmtSelectOTPEncryptedData = `
|
||||||
|
SELECT id, password
|
||||||
|
FROM %s;`
|
||||||
|
|
||||||
|
queryFmtUpdateOTPEncryptedData = `
|
||||||
|
UPDATE %s
|
||||||
|
SET password = ?
|
||||||
|
WHERE id = ?;`
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
queryFmtSelectTOTPConfiguration = `
|
queryFmtSelectTOTPConfiguration = `
|
||||||
SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret
|
SELECT id, created_at, last_used_at, username, issuer, algorithm, digits, period, secret
|
||||||
|
@ -87,8 +117,7 @@ const (
|
||||||
SELECT id, secret
|
SELECT id, secret
|
||||||
FROM %s;`
|
FROM %s;`
|
||||||
|
|
||||||
//nolint:gosec // These are not hardcoded credentials it's a query to obtain credentials.
|
queryFmtUpdateTOTPConfigurationEncryptedData = `
|
||||||
queryFmtUpdateTOTPConfigurationSecret = `
|
|
||||||
UPDATE %s
|
UPDATE %s
|
||||||
SET secret = ?
|
SET secret = ?
|
||||||
WHERE id = ?;`
|
WHERE id = ?;`
|
||||||
|
@ -241,6 +270,14 @@ const (
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
ON CONFLICT (name)
|
ON CONFLICT (name)
|
||||||
DO UPDATE SET value = $2;`
|
DO UPDATE SET value = $2;`
|
||||||
|
|
||||||
|
queryFmtSelectEncryptionEncryptedData = `
|
||||||
|
SELECT id, value
|
||||||
|
FROM %s;`
|
||||||
|
|
||||||
|
queryFmtUpdateEncryptionEncryptedData = `
|
||||||
|
UPDATE %s
|
||||||
|
SET value = ?`
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -251,7 +251,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio
|
||||||
|
|
||||||
if migration.Version == 1 && migration.Up {
|
if migration.Version == 1 && migration.Up {
|
||||||
// Add the schema encryption value if upgrading to v1.
|
// Add the schema encryption value if upgrading to v1.
|
||||||
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
|
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.keys.encryption); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,14 +32,24 @@ type encOAuth2Session struct {
|
||||||
Session []byte `db:"session_data"`
|
Session []byte `db:"session_data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type encOneTimePassword struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
OTP []byte `db:"otp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type encTOTPConfiguration struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Secret []byte `db:"secret"`
|
||||||
|
}
|
||||||
|
|
||||||
type encWebAuthnDevice struct {
|
type encWebAuthnDevice struct {
|
||||||
ID int `db:"id"`
|
ID int `db:"id"`
|
||||||
PublicKey []byte `db:"public_key"`
|
PublicKey []byte `db:"public_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type encTOTPConfiguration struct {
|
type encEncryption struct {
|
||||||
ID int `db:"id" json:"-"`
|
ID int `db:"id"`
|
||||||
Secret []byte `db:"secret" json:"-"`
|
Value []byte `db:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncryptionValidationResult contains information about the success of a schema encryption validation.
|
// EncryptionValidationResult contains information about the success of a schema encryption validation.
|
||||||
|
|
Loading…
Reference in New Issue