From 9a8c6602dd44ee1231fc6d258df4f3738571e1f8 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 21 Jan 2022 10:46:13 +1100 Subject: [PATCH] fix(middlewares): smart delay on reset password (#2767) This adds a smart delay on reset password attempts to prevent username enumeration. Additionally utilizes crypto rand instead of math rand. It also moves the timing delay functionality into its own handler func. --- internal/handlers/const.go | 6 -- internal/handlers/handler_firstfactor.go | 54 +----------- internal/handlers/handler_firstfactor_test.go | 84 ++++--------------- internal/handlers/handler_register_totp.go | 2 +- .../handlers/handler_register_u2f_step1.go | 2 +- .../handlers/handler_reset_password_step1.go | 3 +- internal/middlewares/authelia_context.go | 3 +- internal/middlewares/identity_verification.go | 12 ++- .../middlewares/identity_verification_test.go | 12 +-- internal/middlewares/timing_attack_delay.go | 72 ++++++++++++++++ .../middlewares/timing_attack_delay_test.go | 65 ++++++++++++++ internal/server/server.go | 3 +- 12 files changed, 179 insertions(+), 139 deletions(-) create mode 100644 internal/middlewares/timing_attack_delay.go create mode 100644 internal/middlewares/timing_attack_delay_test.go diff --git a/internal/handlers/const.go b/internal/handlers/const.go index d3bc4e986..ba6e3ce18 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -61,12 +61,6 @@ const ( testUsername = "john" ) -const ( - loginDelayMovingAverageWindow = 10 - loginDelayMinimumDelayMilliseconds = float64(250) - loginDelayMaximumRandomDelayMilliseconds = int64(85) -) - // Duo constants. const ( allow = "allow" diff --git a/internal/handlers/handler_firstfactor.go b/internal/handlers/handler_firstfactor.go index 2b02f1f23..83f7e8ff2 100644 --- a/internal/handlers/handler_firstfactor.go +++ b/internal/handlers/handler_firstfactor.go @@ -2,9 +2,6 @@ package handlers import ( "errors" - "math" - "math/rand" - "sync" "time" "github.com/authelia/authelia/v4/internal/middlewares" @@ -12,61 +9,16 @@ import ( "github.com/authelia/authelia/v4/internal/session" ) -func movingAverageIteration(value time.Duration, successful bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) float64 { - mutex.Lock() - if successful { - (*execDurationMovingAverage)[*movingAverageCursor] = value - *movingAverageCursor = (*movingAverageCursor + 1) % loginDelayMovingAverageWindow - } - - var sum int64 - - for _, v := range *execDurationMovingAverage { - sum += v.Milliseconds() - } - mutex.Unlock() - - return float64(sum / loginDelayMovingAverageWindow) -} - -func calculateActualDelay(ctx *middlewares.AutheliaCtx, execDuration time.Duration, avgExecDurationMs float64, successful *bool) float64 { - randomDelayMs := float64(rand.Int63n(loginDelayMaximumRandomDelayMilliseconds)) //nolint:gosec // TODO: Consider use of crypto/rand, this should be benchmarked and measured first. - totalDelayMs := math.Max(avgExecDurationMs, loginDelayMinimumDelayMilliseconds) + randomDelayMs - actualDelayMs := math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0) - ctx.Logger.Tracef("Attempt successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", *successful, execDuration.Milliseconds(), int64(avgExecDurationMs), int64(randomDelayMs), int64(totalDelayMs), int64(actualDelayMs)) - - return actualDelayMs -} - -func delayToPreventTimingAttacks(ctx *middlewares.AutheliaCtx, requestTime time.Time, successful *bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) { - execDuration := time.Since(requestTime) - avgExecDurationMs := movingAverageIteration(execDuration, *successful, movingAverageCursor, execDurationMovingAverage, mutex) - actualDelayMs := calculateActualDelay(ctx, execDuration, avgExecDurationMs, successful) - time.Sleep(time.Duration(actualDelayMs) * time.Millisecond) -} - // FirstFactorPost is the handler performing the first factory. //nolint:gocyclo // TODO: Consider refactoring time permitting. -func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middlewares.RequestHandler { - var execDurationMovingAverage = make([]time.Duration, loginDelayMovingAverageWindow) - - var movingAverageCursor = 0 - - var mutex = &sync.Mutex{} - - for i := range execDurationMovingAverage { - execDurationMovingAverage[i] = msInitialDelay * time.Millisecond - } - - rand.Seed(time.Now().UnixNano()) - +func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.RequestHandler { return func(ctx *middlewares.AutheliaCtx) { var successful bool requestTime := time.Now() - if delayEnabled { - defer delayToPreventTimingAttacks(ctx, requestTime, &successful, &movingAverageCursor, &execDurationMovingAverage, mutex) + if delayFunc != nil { + defer delayFunc(ctx.Logger, requestTime, &successful) } bodyJSON := firstFactorRequestBody{} diff --git a/internal/handlers/handler_firstfactor_test.go b/internal/handlers/handler_firstfactor_test.go index fe26016a7..5828883a7 100644 --- a/internal/handlers/handler_firstfactor_test.go +++ b/internal/handlers/handler_firstfactor_test.go @@ -2,9 +2,7 @@ package handlers import ( "fmt" - "sync" "testing" - "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -33,7 +31,7 @@ func (s *FirstFactorSuite) TearDownTest() { } func (s *FirstFactorSuite) TestShouldFailIfBodyIsNil() { - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // No body assert.Equal(s.T(), "Failed to parse 1FA request body: unable to parse body: unexpected end of JSON input", s.mock.Hook.LastEntry().Message) @@ -45,7 +43,7 @@ func (s *FirstFactorSuite) TestShouldFailIfBodyIsInBadFormat() { s.mock.Ctx.Request.SetBodyString(`{ "username": "test" }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) assert.Equal(s.T(), "Failed to parse 1FA request body: unable to validate body: password: non zero value required", s.mock.Hook.LastEntry().Message) s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") @@ -73,7 +71,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() { "password": "hello", "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) assert.Equal(s.T(), "Unsuccessful 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message) s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") @@ -102,7 +100,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) } func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCredentials() { @@ -128,7 +126,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) } func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() { @@ -152,7 +150,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() { "password": "hello", "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) assert.Equal(s.T(), "Could not obtain profile details during 1FA authentication for user 'test': failed", s.mock.Hook.LastEntry().Message) s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") @@ -174,7 +172,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() { "password": "hello", "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) assert.Equal(s.T(), "Unable to mark 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message) s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.") @@ -205,7 +203,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() { "password": "hello", "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) @@ -246,7 +244,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() { "requestMethod": "GET", "keepMeLoggedIn": false }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) @@ -290,7 +288,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess "requestMethod": "GET", "keepMeLoggedIn": true }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) @@ -360,7 +358,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenNoTarget "requestMethod": "GET", "keepMeLoggedIn": false }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"}) @@ -381,7 +379,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenURLIsUns "targetURL": "http://notsafe.local" }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"}) @@ -404,7 +402,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenNoTargetURLProvidedA "keepMeLoggedIn": false }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. s.mock.Assert200OK(s.T(), nil) @@ -436,7 +434,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi "keepMeLoggedIn": false }`) - FirstFactorPost(0, false)(s.mock.Ctx) + FirstFactorPost(nil)(s.mock.Ctx) // Respond with 200. s.mock.Assert200OK(s.T(), nil) @@ -446,57 +444,3 @@ func TestFirstFactorSuite(t *testing.T) { suite.Run(t, new(FirstFactorSuite)) suite.Run(t, new(FirstFactorRedirectionSuite)) } - -func TestFirstFactorDelayAverages(t *testing.T) { - execDuration := time.Millisecond * 500 - oneSecond := time.Millisecond * 1000 - durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond} - cursor := 0 - mutex := &sync.Mutex{} - avgExecDuration := movingAverageIteration(execDuration, false, &cursor, &durations, mutex) - assert.Equal(t, avgExecDuration, float64(1000)) - - execDurations := []time.Duration{ - time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, - time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, - time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, - } - - current := float64(1000) - - // Execute at 500ms for 12 requests. - for _, execDuration = range execDurations { - // Should not dip below 500, and should decrease in value by 50 each iteration. - if current > 500 { - current -= 50 - } - - avgExecDuration := movingAverageIteration(execDuration, true, &cursor, &durations, mutex) - assert.Equal(t, avgExecDuration, current) - } -} - -func TestFirstFactorDelayCalculations(t *testing.T) { - mock := mocks.NewMockAutheliaCtx(t) - successful := false - - execDuration := 500 * time.Millisecond - avgExecDurationMs := 1000.0 - expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds()) - - for i := 0; i < 100; i++ { - delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful) - assert.True(t, delay >= expectedMinimumDelayMs) - assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds)) - } - - execDuration = 5 * time.Millisecond - avgExecDurationMs = 5.0 - expectedMinimumDelayMs = loginDelayMinimumDelayMilliseconds - float64(execDuration.Milliseconds()) - - for i := 0; i < 100; i++ { - delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful) - assert.True(t, delay >= expectedMinimumDelayMs) - assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds)) - } -} diff --git a/internal/handlers/handler_register_totp.go b/internal/handlers/handler_register_totp.go index 4542d2644..d3e5f3550 100644 --- a/internal/handlers/handler_register_totp.go +++ b/internal/handlers/handler_register_totp.go @@ -33,7 +33,7 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle TargetEndpoint: "/one-time-password/register", ActionClaim: ActionTOTPRegistration, IdentityRetrieverFunc: identityRetrieverFromSession, -}) +}, nil) func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { var ( diff --git a/internal/handlers/handler_register_u2f_step1.go b/internal/handlers/handler_register_u2f_step1.go index b779cf919..fc2304b91 100644 --- a/internal/handlers/handler_register_u2f_step1.go +++ b/internal/handlers/handler_register_u2f_step1.go @@ -21,7 +21,7 @@ var SecondFactorU2FIdentityStart = middlewares.IdentityVerificationStart(middlew TargetEndpoint: "/security-key/register", ActionClaim: ActionU2FRegistration, IdentityRetrieverFunc: identityRetrieverFromSession, -}) +}, nil) func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { if ctx.XForwardedProto() == nil { diff --git a/internal/handlers/handler_reset_password_step1.go b/internal/handlers/handler_reset_password_step1.go index 7f460bd8c..75cff5f0e 100644 --- a/internal/handlers/handler_reset_password_step1.go +++ b/internal/handlers/handler_reset_password_step1.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "fmt" + "time" "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/session" @@ -40,7 +41,7 @@ var ResetPasswordIdentityStart = middlewares.IdentityVerificationStart(middlewar TargetEndpoint: "/reset-password/step2", ActionClaim: ActionResetPassword, IdentityRetrieverFunc: identityRetrieverFromStorage, -}) +}, middlewares.TimingAttackDelay(10, 250, 85, time.Millisecond*500)) func resetPasswordIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { userSession := ctx.GetSession() diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 1552cd3d7..ec4b0dce7 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -13,13 +13,14 @@ import ( "github.com/valyala/fasthttp" "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/utils" ) // NewRequestLogger create a new request logger for the given request. func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry { - return logrus.WithFields(logrus.Fields{ + return logging.Logger().WithFields(logrus.Fields{ "method": string(ctx.Method()), "path": string(ctx.Path()), "remote_ip": ctx.RemoteIP().String(), diff --git a/internal/middlewares/identity_verification.go b/internal/middlewares/identity_verification.go index f416c7566..22ea47936 100644 --- a/internal/middlewares/identity_verification.go +++ b/internal/middlewares/identity_verification.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "time" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" @@ -13,12 +14,19 @@ import ( ) // IdentityVerificationStart the handler for initiating the identity validation process. -func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandler { +func IdentityVerificationStart(args IdentityVerificationStartArgs, delayFunc TimingAttackDelayFunc) RequestHandler { if args.IdentityRetrieverFunc == nil { panic(fmt.Errorf("Identity verification requires an identity retriever")) } return func(ctx *AutheliaCtx) { + requestTime := time.Now() + success := false + + if delayFunc != nil { + defer delayFunc(ctx.Logger, requestTime, &success) + } + identity, err := args.IdentityRetrieverFunc(ctx) if err != nil { // In that case we reply ok to avoid user enumeration. @@ -106,6 +114,8 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle return } + success = true + ctx.ReplyOK() } } diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go index 7c4a1f893..38e592c74 100644 --- a/internal/middlewares/identity_verification_test.go +++ b/internal/middlewares/identity_verification_test.go @@ -44,7 +44,7 @@ func TestShouldFailStartingProcessIfUserHasNoEmailAddress(t *testing.T) { return nil, fmt.Errorf("User does not have any email") } - middlewares.IdentityVerificationStart(newArgs(retriever))(mock.Ctx) + middlewares.IdentityVerificationStart(newArgs(retriever), nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message) @@ -61,7 +61,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) { Return(fmt.Errorf("cannot save")) args := newArgs(defaultRetriever) - middlewares.IdentityVerificationStart(args)(mock.Ctx) + middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message) @@ -84,7 +84,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) { Return(fmt.Errorf("no notif")) args := newArgs(defaultRetriever) - middlewares.IdentityVerificationStart(args)(mock.Ctx) + middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "no notif", mock.Hook.LastEntry().Message) @@ -102,7 +102,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) { Return(nil) args := newArgs(defaultRetriever) - middlewares.IdentityVerificationStart(args)(mock.Ctx) + middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message) @@ -120,7 +120,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) { Return(nil) args := newArgs(defaultRetriever) - middlewares.IdentityVerificationStart(args)(mock.Ctx) + middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message) @@ -142,7 +142,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) { Return(nil) args := newArgs(defaultRetriever) - middlewares.IdentityVerificationStart(args)(mock.Ctx) + middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) diff --git a/internal/middlewares/timing_attack_delay.go b/internal/middlewares/timing_attack_delay.go new file mode 100644 index 000000000..60339b0ce --- /dev/null +++ b/internal/middlewares/timing_attack_delay.go @@ -0,0 +1,72 @@ +package middlewares + +import ( + "crypto/rand" + "math" + "math/big" + "sync" + "time" + + "github.com/sirupsen/logrus" +) + +// TimingAttackDelayFunc describes a function for preventing timing attacks via a delay. +type TimingAttackDelayFunc func(logger *logrus.Entry, requestTime time.Time, successful *bool) + +// TimingAttackDelay creates a new standard timing delay func. +func TimingAttackDelay(history int, minDelayMs float64, maxRandomMs int64, initialDelay time.Duration) TimingAttackDelayFunc { + var ( + mutex = &sync.Mutex{} + cursor = 0 + ) + + execDurationMovingAverage := make([]time.Duration, history) + + for i := range execDurationMovingAverage { + execDurationMovingAverage[i] = initialDelay + } + + return func(logger *logrus.Entry, requestTime time.Time, successful *bool) { + successfulValue := false + if successful != nil { + successfulValue = *successful + } + + execDuration := time.Since(requestTime) + execDurationAvgMs := movingAverageIteration(execDuration, history, successfulValue, &cursor, &execDurationMovingAverage, mutex) + actualDelayMs := calculateActualDelay(logger, execDuration, execDurationAvgMs, minDelayMs, maxRandomMs, successfulValue) + time.Sleep(time.Duration(actualDelayMs) * time.Millisecond) + } +} + +func movingAverageIteration(value time.Duration, history int, successful bool, cursor *int, movingAvg *[]time.Duration, mutex sync.Locker) float64 { + mutex.Lock() + + var sum int64 + + for _, v := range *movingAvg { + sum += v.Milliseconds() + } + + if successful { + (*movingAvg)[*cursor] = value + *cursor = (*cursor + 1) % history + } + + mutex.Unlock() + + return float64(sum / int64(history)) +} + +func calculateActualDelay(logger *logrus.Entry, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) { + randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs)) + if err != nil { + return float64(maxRandomMs) + } + + totalDelayMs := math.Max(execDurationAvgMs, minDelayMs) + float64(randomDelayMs.Int64()) + actualDelayMs = math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0) + logger.Tracef("Timing Attack Delay successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", successful, execDuration.Milliseconds(), int64(execDurationAvgMs), randomDelayMs.Int64(), int64(totalDelayMs), int64(actualDelayMs)) + + return actualDelayMs +} diff --git a/internal/middlewares/timing_attack_delay_test.go b/internal/middlewares/timing_attack_delay_test.go new file mode 100644 index 000000000..88243de12 --- /dev/null +++ b/internal/middlewares/timing_attack_delay_test.go @@ -0,0 +1,65 @@ +package middlewares + +import ( + "sync" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/authelia/authelia/v4/internal/logging" +) + +func TestTimingAttackDelayAverages(t *testing.T) { + execDuration := time.Millisecond * 500 + oneSecond := time.Millisecond * 1000 + durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond} + cursor := 0 + mutex := &sync.Mutex{} + avgExecDuration := movingAverageIteration(execDuration, 10, false, &cursor, &durations, mutex) + assert.Equal(t, avgExecDuration, float64(1000)) + + execDurations := []time.Duration{ + time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, + time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, + time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, + } + + current := float64(1000) + + // Execute at 500ms for 12 requests. + for _, execDuration = range execDurations { + avgExecDuration = movingAverageIteration(execDuration, 10, true, &cursor, &durations, mutex) + assert.Equal(t, avgExecDuration, current) + + // Should not dip below 500, and should decrease in value by 50 each iteration. + if current > 500 { + current -= 50 + } + } +} + +func TestTimingAttackDelayCalculations(t *testing.T) { + execDuration := 500 * time.Millisecond + avgExecDurationMs := 1000.0 + expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds()) + + logger := logging.Logger().WithFields(logrus.Fields{}) + + for i := 0; i < 100; i++ { + delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false) + assert.True(t, delay >= expectedMinimumDelayMs) + assert.True(t, delay <= expectedMinimumDelayMs+float64(85)) + } + + execDuration = 5 * time.Millisecond + avgExecDurationMs = 5.0 + expectedMinimumDelayMs = 250 - float64(execDuration.Milliseconds()) + + for i := 0; i < 100; i++ { + delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false) + assert.True(t, delay >= expectedMinimumDelayMs) + assert.True(t, delay <= expectedMinimumDelayMs+float64(250)) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 52181c9ca..d75bb0d4f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "strconv" + "time" duoapi "github.com/duosecurity/duo_api_golang" "github.com/fasthttp/router" @@ -69,7 +70,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr r.POST("/api/checks/safe-redirection", autheliaMiddleware(handlers.CheckSafeRedirection)) - r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(1000, true))) + r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(middlewares.TimingAttackDelay(10, 250, 85, time.Second)))) r.POST("/api/logout", autheliaMiddleware(handlers.LogoutPost)) // Only register endpoints if forgot password is not disabled.