fix(middlewares): smart delay on reset password (#2767)

This adds a smart delay on reset password attempts to prevent username enumeration. Additionally utilizes crypto rand instead of math rand. It also moves the timing delay functionality into its own handler func.
pull/2810/head^2
James Elliott 2022-01-21 10:46:13 +11:00 committed by GitHub
parent 97a862e81a
commit 9a8c6602dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 179 additions and 139 deletions

View File

@ -61,12 +61,6 @@ const (
testUsername = "john" testUsername = "john"
) )
const (
loginDelayMovingAverageWindow = 10
loginDelayMinimumDelayMilliseconds = float64(250)
loginDelayMaximumRandomDelayMilliseconds = int64(85)
)
// Duo constants. // Duo constants.
const ( const (
allow = "allow" allow = "allow"

View File

@ -2,9 +2,6 @@ package handlers
import ( import (
"errors" "errors"
"math"
"math/rand"
"sync"
"time" "time"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
@ -12,61 +9,16 @@ import (
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
) )
func movingAverageIteration(value time.Duration, successful bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) float64 {
mutex.Lock()
if successful {
(*execDurationMovingAverage)[*movingAverageCursor] = value
*movingAverageCursor = (*movingAverageCursor + 1) % loginDelayMovingAverageWindow
}
var sum int64
for _, v := range *execDurationMovingAverage {
sum += v.Milliseconds()
}
mutex.Unlock()
return float64(sum / loginDelayMovingAverageWindow)
}
func calculateActualDelay(ctx *middlewares.AutheliaCtx, execDuration time.Duration, avgExecDurationMs float64, successful *bool) float64 {
randomDelayMs := float64(rand.Int63n(loginDelayMaximumRandomDelayMilliseconds)) //nolint:gosec // TODO: Consider use of crypto/rand, this should be benchmarked and measured first.
totalDelayMs := math.Max(avgExecDurationMs, loginDelayMinimumDelayMilliseconds) + randomDelayMs
actualDelayMs := math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
ctx.Logger.Tracef("Attempt successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", *successful, execDuration.Milliseconds(), int64(avgExecDurationMs), int64(randomDelayMs), int64(totalDelayMs), int64(actualDelayMs))
return actualDelayMs
}
func delayToPreventTimingAttacks(ctx *middlewares.AutheliaCtx, requestTime time.Time, successful *bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) {
execDuration := time.Since(requestTime)
avgExecDurationMs := movingAverageIteration(execDuration, *successful, movingAverageCursor, execDurationMovingAverage, mutex)
actualDelayMs := calculateActualDelay(ctx, execDuration, avgExecDurationMs, successful)
time.Sleep(time.Duration(actualDelayMs) * time.Millisecond)
}
// FirstFactorPost is the handler performing the first factory. // FirstFactorPost is the handler performing the first factory.
//nolint:gocyclo // TODO: Consider refactoring time permitting. //nolint:gocyclo // TODO: Consider refactoring time permitting.
func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middlewares.RequestHandler { func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.RequestHandler {
var execDurationMovingAverage = make([]time.Duration, loginDelayMovingAverageWindow)
var movingAverageCursor = 0
var mutex = &sync.Mutex{}
for i := range execDurationMovingAverage {
execDurationMovingAverage[i] = msInitialDelay * time.Millisecond
}
rand.Seed(time.Now().UnixNano())
return func(ctx *middlewares.AutheliaCtx) { return func(ctx *middlewares.AutheliaCtx) {
var successful bool var successful bool
requestTime := time.Now() requestTime := time.Now()
if delayEnabled { if delayFunc != nil {
defer delayToPreventTimingAttacks(ctx, requestTime, &successful, &movingAverageCursor, &execDurationMovingAverage, mutex) defer delayFunc(ctx.Logger, requestTime, &successful)
} }
bodyJSON := firstFactorRequestBody{} bodyJSON := firstFactorRequestBody{}

View File

@ -2,9 +2,7 @@ package handlers
import ( import (
"fmt" "fmt"
"sync"
"testing" "testing"
"time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -33,7 +31,7 @@ func (s *FirstFactorSuite) TearDownTest() {
} }
func (s *FirstFactorSuite) TestShouldFailIfBodyIsNil() { func (s *FirstFactorSuite) TestShouldFailIfBodyIsNil() {
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// No body // No body
assert.Equal(s.T(), "Failed to parse 1FA request body: unable to parse body: unexpected end of JSON input", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "Failed to parse 1FA request body: unable to parse body: unexpected end of JSON input", s.mock.Hook.LastEntry().Message)
@ -45,7 +43,7 @@ func (s *FirstFactorSuite) TestShouldFailIfBodyIsInBadFormat() {
s.mock.Ctx.Request.SetBodyString(`{ s.mock.Ctx.Request.SetBodyString(`{
"username": "test" "username": "test"
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
assert.Equal(s.T(), "Failed to parse 1FA request body: unable to validate body: password: non zero value required", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "Failed to parse 1FA request body: unable to validate body: password: non zero value required", s.mock.Hook.LastEntry().Message)
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
@ -73,7 +71,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
"password": "hello", "password": "hello",
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
assert.Equal(s.T(), "Unsuccessful 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "Unsuccessful 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message)
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
@ -102,7 +100,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
} }
func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCredentials() { func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCredentials() {
@ -128,7 +126,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
} }
func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() { func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
@ -152,7 +150,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
"password": "hello", "password": "hello",
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
assert.Equal(s.T(), "Could not obtain profile details during 1FA authentication for user 'test': failed", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "Could not obtain profile details during 1FA authentication for user 'test': failed", s.mock.Hook.LastEntry().Message)
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
@ -174,7 +172,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() {
"password": "hello", "password": "hello",
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
assert.Equal(s.T(), "Unable to mark 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "Unable to mark 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message)
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
@ -205,7 +203,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() {
"password": "hello", "password": "hello",
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
@ -246,7 +244,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() {
"requestMethod": "GET", "requestMethod": "GET",
"keepMeLoggedIn": false "keepMeLoggedIn": false
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
@ -290,7 +288,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess
"requestMethod": "GET", "requestMethod": "GET",
"keepMeLoggedIn": true "keepMeLoggedIn": true
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
@ -360,7 +358,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenNoTarget
"requestMethod": "GET", "requestMethod": "GET",
"keepMeLoggedIn": false "keepMeLoggedIn": false
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"}) s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"})
@ -381,7 +379,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenURLIsUns
"targetURL": "http://notsafe.local" "targetURL": "http://notsafe.local"
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"}) s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"})
@ -404,7 +402,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenNoTargetURLProvidedA
"keepMeLoggedIn": false "keepMeLoggedIn": false
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
s.mock.Assert200OK(s.T(), nil) s.mock.Assert200OK(s.T(), nil)
@ -436,7 +434,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
"keepMeLoggedIn": false "keepMeLoggedIn": false
}`) }`)
FirstFactorPost(0, false)(s.mock.Ctx) FirstFactorPost(nil)(s.mock.Ctx)
// Respond with 200. // Respond with 200.
s.mock.Assert200OK(s.T(), nil) s.mock.Assert200OK(s.T(), nil)
@ -446,57 +444,3 @@ func TestFirstFactorSuite(t *testing.T) {
suite.Run(t, new(FirstFactorSuite)) suite.Run(t, new(FirstFactorSuite))
suite.Run(t, new(FirstFactorRedirectionSuite)) suite.Run(t, new(FirstFactorRedirectionSuite))
} }
func TestFirstFactorDelayAverages(t *testing.T) {
execDuration := time.Millisecond * 500
oneSecond := time.Millisecond * 1000
durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond}
cursor := 0
mutex := &sync.Mutex{}
avgExecDuration := movingAverageIteration(execDuration, false, &cursor, &durations, mutex)
assert.Equal(t, avgExecDuration, float64(1000))
execDurations := []time.Duration{
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
}
current := float64(1000)
// Execute at 500ms for 12 requests.
for _, execDuration = range execDurations {
// Should not dip below 500, and should decrease in value by 50 each iteration.
if current > 500 {
current -= 50
}
avgExecDuration := movingAverageIteration(execDuration, true, &cursor, &durations, mutex)
assert.Equal(t, avgExecDuration, current)
}
}
func TestFirstFactorDelayCalculations(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
successful := false
execDuration := 500 * time.Millisecond
avgExecDurationMs := 1000.0
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
for i := 0; i < 100; i++ {
delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful)
assert.True(t, delay >= expectedMinimumDelayMs)
assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds))
}
execDuration = 5 * time.Millisecond
avgExecDurationMs = 5.0
expectedMinimumDelayMs = loginDelayMinimumDelayMilliseconds - float64(execDuration.Milliseconds())
for i := 0; i < 100; i++ {
delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful)
assert.True(t, delay >= expectedMinimumDelayMs)
assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds))
}
}

View File

@ -33,7 +33,7 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle
TargetEndpoint: "/one-time-password/register", TargetEndpoint: "/one-time-password/register",
ActionClaim: ActionTOTPRegistration, ActionClaim: ActionTOTPRegistration,
IdentityRetrieverFunc: identityRetrieverFromSession, IdentityRetrieverFunc: identityRetrieverFromSession,
}) }, nil)
func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
var ( var (

View File

@ -21,7 +21,7 @@ var SecondFactorU2FIdentityStart = middlewares.IdentityVerificationStart(middlew
TargetEndpoint: "/security-key/register", TargetEndpoint: "/security-key/register",
ActionClaim: ActionU2FRegistration, ActionClaim: ActionU2FRegistration,
IdentityRetrieverFunc: identityRetrieverFromSession, IdentityRetrieverFunc: identityRetrieverFromSession,
}) }, nil)
func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
if ctx.XForwardedProto() == nil { if ctx.XForwardedProto() == nil {

View File

@ -3,6 +3,7 @@ package handlers
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"time"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
@ -40,7 +41,7 @@ var ResetPasswordIdentityStart = middlewares.IdentityVerificationStart(middlewar
TargetEndpoint: "/reset-password/step2", TargetEndpoint: "/reset-password/step2",
ActionClaim: ActionResetPassword, ActionClaim: ActionResetPassword,
IdentityRetrieverFunc: identityRetrieverFromStorage, IdentityRetrieverFunc: identityRetrieverFromStorage,
}) }, middlewares.TimingAttackDelay(10, 250, 85, time.Millisecond*500))
func resetPasswordIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { func resetPasswordIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
userSession := ctx.GetSession() userSession := ctx.GetSession()

View File

@ -13,13 +13,14 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// NewRequestLogger create a new request logger for the given request. // NewRequestLogger create a new request logger for the given request.
func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry { func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry {
return logrus.WithFields(logrus.Fields{ return logging.Logger().WithFields(logrus.Fields{
"method": string(ctx.Method()), "method": string(ctx.Method()),
"path": string(ctx.Path()), "path": string(ctx.Path()),
"remote_ip": ctx.RemoteIP().String(), "remote_ip": ctx.RemoteIP().String(),

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/google/uuid" "github.com/google/uuid"
@ -13,12 +14,19 @@ import (
) )
// IdentityVerificationStart the handler for initiating the identity validation process. // IdentityVerificationStart the handler for initiating the identity validation process.
func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandler { func IdentityVerificationStart(args IdentityVerificationStartArgs, delayFunc TimingAttackDelayFunc) RequestHandler {
if args.IdentityRetrieverFunc == nil { if args.IdentityRetrieverFunc == nil {
panic(fmt.Errorf("Identity verification requires an identity retriever")) panic(fmt.Errorf("Identity verification requires an identity retriever"))
} }
return func(ctx *AutheliaCtx) { return func(ctx *AutheliaCtx) {
requestTime := time.Now()
success := false
if delayFunc != nil {
defer delayFunc(ctx.Logger, requestTime, &success)
}
identity, err := args.IdentityRetrieverFunc(ctx) identity, err := args.IdentityRetrieverFunc(ctx)
if err != nil { if err != nil {
// In that case we reply ok to avoid user enumeration. // In that case we reply ok to avoid user enumeration.
@ -106,6 +114,8 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
return return
} }
success = true
ctx.ReplyOK() ctx.ReplyOK()
} }
} }

View File

@ -44,7 +44,7 @@ func TestShouldFailStartingProcessIfUserHasNoEmailAddress(t *testing.T) {
return nil, fmt.Errorf("User does not have any email") return nil, fmt.Errorf("User does not have any email")
} }
middlewares.IdentityVerificationStart(newArgs(retriever))(mock.Ctx) middlewares.IdentityVerificationStart(newArgs(retriever), nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message) assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message)
@ -61,7 +61,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) {
Return(fmt.Errorf("cannot save")) Return(fmt.Errorf("cannot save"))
args := newArgs(defaultRetriever) args := newArgs(defaultRetriever)
middlewares.IdentityVerificationStart(args)(mock.Ctx) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message) assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message)
@ -84,7 +84,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
Return(fmt.Errorf("no notif")) Return(fmt.Errorf("no notif"))
args := newArgs(defaultRetriever) args := newArgs(defaultRetriever)
middlewares.IdentityVerificationStart(args)(mock.Ctx) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message) assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
@ -102,7 +102,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
Return(nil) Return(nil)
args := newArgs(defaultRetriever) args := newArgs(defaultRetriever)
middlewares.IdentityVerificationStart(args)(mock.Ctx) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message) assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message)
@ -120,7 +120,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
Return(nil) Return(nil)
args := newArgs(defaultRetriever) args := newArgs(defaultRetriever)
middlewares.IdentityVerificationStart(args)(mock.Ctx) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message) assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message)
@ -142,7 +142,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
Return(nil) Return(nil)
args := newArgs(defaultRetriever) args := newArgs(defaultRetriever)
middlewares.IdentityVerificationStart(args)(mock.Ctx) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())

View File

@ -0,0 +1,72 @@
package middlewares
import (
"crypto/rand"
"math"
"math/big"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// TimingAttackDelayFunc describes a function for preventing timing attacks via a delay.
type TimingAttackDelayFunc func(logger *logrus.Entry, requestTime time.Time, successful *bool)
// TimingAttackDelay creates a new standard timing delay func.
func TimingAttackDelay(history int, minDelayMs float64, maxRandomMs int64, initialDelay time.Duration) TimingAttackDelayFunc {
var (
mutex = &sync.Mutex{}
cursor = 0
)
execDurationMovingAverage := make([]time.Duration, history)
for i := range execDurationMovingAverage {
execDurationMovingAverage[i] = initialDelay
}
return func(logger *logrus.Entry, requestTime time.Time, successful *bool) {
successfulValue := false
if successful != nil {
successfulValue = *successful
}
execDuration := time.Since(requestTime)
execDurationAvgMs := movingAverageIteration(execDuration, history, successfulValue, &cursor, &execDurationMovingAverage, mutex)
actualDelayMs := calculateActualDelay(logger, execDuration, execDurationAvgMs, minDelayMs, maxRandomMs, successfulValue)
time.Sleep(time.Duration(actualDelayMs) * time.Millisecond)
}
}
func movingAverageIteration(value time.Duration, history int, successful bool, cursor *int, movingAvg *[]time.Duration, mutex sync.Locker) float64 {
mutex.Lock()
var sum int64
for _, v := range *movingAvg {
sum += v.Milliseconds()
}
if successful {
(*movingAvg)[*cursor] = value
*cursor = (*cursor + 1) % history
}
mutex.Unlock()
return float64(sum / int64(history))
}
func calculateActualDelay(logger *logrus.Entry, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs))
if err != nil {
return float64(maxRandomMs)
}
totalDelayMs := math.Max(execDurationAvgMs, minDelayMs) + float64(randomDelayMs.Int64())
actualDelayMs = math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
logger.Tracef("Timing Attack Delay successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", successful, execDuration.Milliseconds(), int64(execDurationAvgMs), randomDelayMs.Int64(), int64(totalDelayMs), int64(actualDelayMs))
return actualDelayMs
}

View File

@ -0,0 +1,65 @@
package middlewares
import (
"sync"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/authelia/authelia/v4/internal/logging"
)
func TestTimingAttackDelayAverages(t *testing.T) {
execDuration := time.Millisecond * 500
oneSecond := time.Millisecond * 1000
durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond}
cursor := 0
mutex := &sync.Mutex{}
avgExecDuration := movingAverageIteration(execDuration, 10, false, &cursor, &durations, mutex)
assert.Equal(t, avgExecDuration, float64(1000))
execDurations := []time.Duration{
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
}
current := float64(1000)
// Execute at 500ms for 12 requests.
for _, execDuration = range execDurations {
avgExecDuration = movingAverageIteration(execDuration, 10, true, &cursor, &durations, mutex)
assert.Equal(t, avgExecDuration, current)
// Should not dip below 500, and should decrease in value by 50 each iteration.
if current > 500 {
current -= 50
}
}
}
func TestTimingAttackDelayCalculations(t *testing.T) {
execDuration := 500 * time.Millisecond
avgExecDurationMs := 1000.0
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
logger := logging.Logger().WithFields(logrus.Fields{})
for i := 0; i < 100; i++ {
delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false)
assert.True(t, delay >= expectedMinimumDelayMs)
assert.True(t, delay <= expectedMinimumDelayMs+float64(85))
}
execDuration = 5 * time.Millisecond
avgExecDurationMs = 5.0
expectedMinimumDelayMs = 250 - float64(execDuration.Milliseconds())
for i := 0; i < 100; i++ {
delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false)
assert.True(t, delay >= expectedMinimumDelayMs)
assert.True(t, delay <= expectedMinimumDelayMs+float64(250))
}
}

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
"time"
duoapi "github.com/duosecurity/duo_api_golang" duoapi "github.com/duosecurity/duo_api_golang"
"github.com/fasthttp/router" "github.com/fasthttp/router"
@ -69,7 +70,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
r.POST("/api/checks/safe-redirection", autheliaMiddleware(handlers.CheckSafeRedirection)) r.POST("/api/checks/safe-redirection", autheliaMiddleware(handlers.CheckSafeRedirection))
r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(1000, true))) r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(middlewares.TimingAttackDelay(10, 250, 85, time.Second))))
r.POST("/api/logout", autheliaMiddleware(handlers.LogoutPost)) r.POST("/api/logout", autheliaMiddleware(handlers.LogoutPost))
// Only register endpoints if forgot password is not disabled. // Only register endpoints if forgot password is not disabled.