fix(middlewares): smart delay on reset password (#2767)
This adds a smart delay on reset password attempts to prevent username enumeration. Additionally utilizes crypto rand instead of math rand. It also moves the timing delay functionality into its own handler func.pull/2810/head^2
parent
97a862e81a
commit
9a8c6602dd
|
@ -61,12 +61,6 @@ const (
|
|||
testUsername = "john"
|
||||
)
|
||||
|
||||
const (
|
||||
loginDelayMovingAverageWindow = 10
|
||||
loginDelayMinimumDelayMilliseconds = float64(250)
|
||||
loginDelayMaximumRandomDelayMilliseconds = int64(85)
|
||||
)
|
||||
|
||||
// Duo constants.
|
||||
const (
|
||||
allow = "allow"
|
||||
|
|
|
@ -2,9 +2,6 @@ package handlers
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
|
@ -12,61 +9,16 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
func movingAverageIteration(value time.Duration, successful bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) float64 {
|
||||
mutex.Lock()
|
||||
if successful {
|
||||
(*execDurationMovingAverage)[*movingAverageCursor] = value
|
||||
*movingAverageCursor = (*movingAverageCursor + 1) % loginDelayMovingAverageWindow
|
||||
}
|
||||
|
||||
var sum int64
|
||||
|
||||
for _, v := range *execDurationMovingAverage {
|
||||
sum += v.Milliseconds()
|
||||
}
|
||||
mutex.Unlock()
|
||||
|
||||
return float64(sum / loginDelayMovingAverageWindow)
|
||||
}
|
||||
|
||||
func calculateActualDelay(ctx *middlewares.AutheliaCtx, execDuration time.Duration, avgExecDurationMs float64, successful *bool) float64 {
|
||||
randomDelayMs := float64(rand.Int63n(loginDelayMaximumRandomDelayMilliseconds)) //nolint:gosec // TODO: Consider use of crypto/rand, this should be benchmarked and measured first.
|
||||
totalDelayMs := math.Max(avgExecDurationMs, loginDelayMinimumDelayMilliseconds) + randomDelayMs
|
||||
actualDelayMs := math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
|
||||
ctx.Logger.Tracef("Attempt successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", *successful, execDuration.Milliseconds(), int64(avgExecDurationMs), int64(randomDelayMs), int64(totalDelayMs), int64(actualDelayMs))
|
||||
|
||||
return actualDelayMs
|
||||
}
|
||||
|
||||
func delayToPreventTimingAttacks(ctx *middlewares.AutheliaCtx, requestTime time.Time, successful *bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) {
|
||||
execDuration := time.Since(requestTime)
|
||||
avgExecDurationMs := movingAverageIteration(execDuration, *successful, movingAverageCursor, execDurationMovingAverage, mutex)
|
||||
actualDelayMs := calculateActualDelay(ctx, execDuration, avgExecDurationMs, successful)
|
||||
time.Sleep(time.Duration(actualDelayMs) * time.Millisecond)
|
||||
}
|
||||
|
||||
// FirstFactorPost is the handler performing the first factory.
|
||||
//nolint:gocyclo // TODO: Consider refactoring time permitting.
|
||||
func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middlewares.RequestHandler {
|
||||
var execDurationMovingAverage = make([]time.Duration, loginDelayMovingAverageWindow)
|
||||
|
||||
var movingAverageCursor = 0
|
||||
|
||||
var mutex = &sync.Mutex{}
|
||||
|
||||
for i := range execDurationMovingAverage {
|
||||
execDurationMovingAverage[i] = msInitialDelay * time.Millisecond
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.RequestHandler {
|
||||
return func(ctx *middlewares.AutheliaCtx) {
|
||||
var successful bool
|
||||
|
||||
requestTime := time.Now()
|
||||
|
||||
if delayEnabled {
|
||||
defer delayToPreventTimingAttacks(ctx, requestTime, &successful, &movingAverageCursor, &execDurationMovingAverage, mutex)
|
||||
if delayFunc != nil {
|
||||
defer delayFunc(ctx.Logger, requestTime, &successful)
|
||||
}
|
||||
|
||||
bodyJSON := firstFactorRequestBody{}
|
||||
|
|
|
@ -2,9 +2,7 @@ package handlers
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -33,7 +31,7 @@ func (s *FirstFactorSuite) TearDownTest() {
|
|||
}
|
||||
|
||||
func (s *FirstFactorSuite) TestShouldFailIfBodyIsNil() {
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// No body
|
||||
assert.Equal(s.T(), "Failed to parse 1FA request body: unable to parse body: unexpected end of JSON input", s.mock.Hook.LastEntry().Message)
|
||||
|
@ -45,7 +43,7 @@ func (s *FirstFactorSuite) TestShouldFailIfBodyIsInBadFormat() {
|
|||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
"username": "test"
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Failed to parse 1FA request body: unable to validate body: password: non zero value required", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -73,7 +71,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Unsuccessful 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -102,7 +100,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC
|
|||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
}
|
||||
|
||||
func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCredentials() {
|
||||
|
@ -128,7 +126,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
|
|||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
}
|
||||
|
||||
func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
|
||||
|
@ -152,7 +150,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Could not obtain profile details during 1FA authentication for user 'test': failed", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -174,7 +172,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Unable to mark 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -205,7 +203,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||
|
@ -246,7 +244,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() {
|
|||
"requestMethod": "GET",
|
||||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||
|
@ -290,7 +288,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess
|
|||
"requestMethod": "GET",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||
|
@ -360,7 +358,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenNoTarget
|
|||
"requestMethod": "GET",
|
||||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"})
|
||||
|
@ -381,7 +379,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenURLIsUns
|
|||
"targetURL": "http://notsafe.local"
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"})
|
||||
|
@ -404,7 +402,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenNoTargetURLProvidedA
|
|||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), nil)
|
||||
|
@ -436,7 +434,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
|
|||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), nil)
|
||||
|
@ -446,57 +444,3 @@ func TestFirstFactorSuite(t *testing.T) {
|
|||
suite.Run(t, new(FirstFactorSuite))
|
||||
suite.Run(t, new(FirstFactorRedirectionSuite))
|
||||
}
|
||||
|
||||
func TestFirstFactorDelayAverages(t *testing.T) {
|
||||
execDuration := time.Millisecond * 500
|
||||
oneSecond := time.Millisecond * 1000
|
||||
durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond}
|
||||
cursor := 0
|
||||
mutex := &sync.Mutex{}
|
||||
avgExecDuration := movingAverageIteration(execDuration, false, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, float64(1000))
|
||||
|
||||
execDurations := []time.Duration{
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
}
|
||||
|
||||
current := float64(1000)
|
||||
|
||||
// Execute at 500ms for 12 requests.
|
||||
for _, execDuration = range execDurations {
|
||||
// Should not dip below 500, and should decrease in value by 50 each iteration.
|
||||
if current > 500 {
|
||||
current -= 50
|
||||
}
|
||||
|
||||
avgExecDuration := movingAverageIteration(execDuration, true, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, current)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstFactorDelayCalculations(t *testing.T) {
|
||||
mock := mocks.NewMockAutheliaCtx(t)
|
||||
successful := false
|
||||
|
||||
execDuration := 500 * time.Millisecond
|
||||
avgExecDurationMs := 1000.0
|
||||
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds))
|
||||
}
|
||||
|
||||
execDuration = 5 * time.Millisecond
|
||||
avgExecDurationMs = 5.0
|
||||
expectedMinimumDelayMs = loginDelayMinimumDelayMilliseconds - float64(execDuration.Milliseconds())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle
|
|||
TargetEndpoint: "/one-time-password/register",
|
||||
ActionClaim: ActionTOTPRegistration,
|
||||
IdentityRetrieverFunc: identityRetrieverFromSession,
|
||||
})
|
||||
}, nil)
|
||||
|
||||
func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
var (
|
||||
|
|
|
@ -21,7 +21,7 @@ var SecondFactorU2FIdentityStart = middlewares.IdentityVerificationStart(middlew
|
|||
TargetEndpoint: "/security-key/register",
|
||||
ActionClaim: ActionU2FRegistration,
|
||||
IdentityRetrieverFunc: identityRetrieverFromSession,
|
||||
})
|
||||
}, nil)
|
||||
|
||||
func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
if ctx.XForwardedProto() == nil {
|
||||
|
|
|
@ -3,6 +3,7 @@ package handlers
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
|
@ -40,7 +41,7 @@ var ResetPasswordIdentityStart = middlewares.IdentityVerificationStart(middlewar
|
|||
TargetEndpoint: "/reset-password/step2",
|
||||
ActionClaim: ActionResetPassword,
|
||||
IdentityRetrieverFunc: identityRetrieverFromStorage,
|
||||
})
|
||||
}, middlewares.TimingAttackDelay(10, 250, 85, time.Millisecond*500))
|
||||
|
||||
func resetPasswordIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
userSession := ctx.GetSession()
|
||||
|
|
|
@ -13,13 +13,14 @@ import (
|
|||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// NewRequestLogger create a new request logger for the given request.
|
||||
func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry {
|
||||
return logrus.WithFields(logrus.Fields{
|
||||
return logging.Logger().WithFields(logrus.Fields{
|
||||
"method": string(ctx.Method()),
|
||||
"path": string(ctx.Path()),
|
||||
"remote_ip": ctx.RemoteIP().String(),
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
|
@ -13,12 +14,19 @@ import (
|
|||
)
|
||||
|
||||
// IdentityVerificationStart the handler for initiating the identity validation process.
|
||||
func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandler {
|
||||
func IdentityVerificationStart(args IdentityVerificationStartArgs, delayFunc TimingAttackDelayFunc) RequestHandler {
|
||||
if args.IdentityRetrieverFunc == nil {
|
||||
panic(fmt.Errorf("Identity verification requires an identity retriever"))
|
||||
}
|
||||
|
||||
return func(ctx *AutheliaCtx) {
|
||||
requestTime := time.Now()
|
||||
success := false
|
||||
|
||||
if delayFunc != nil {
|
||||
defer delayFunc(ctx.Logger, requestTime, &success)
|
||||
}
|
||||
|
||||
identity, err := args.IdentityRetrieverFunc(ctx)
|
||||
if err != nil {
|
||||
// In that case we reply ok to avoid user enumeration.
|
||||
|
@ -106,6 +114,8 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
|
|||
return
|
||||
}
|
||||
|
||||
success = true
|
||||
|
||||
ctx.ReplyOK()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ func TestShouldFailStartingProcessIfUserHasNoEmailAddress(t *testing.T) {
|
|||
return nil, fmt.Errorf("User does not have any email")
|
||||
}
|
||||
|
||||
middlewares.IdentityVerificationStart(newArgs(retriever))(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(newArgs(retriever), nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message)
|
||||
|
@ -61,7 +61,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) {
|
|||
Return(fmt.Errorf("cannot save"))
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message)
|
||||
|
@ -84,7 +84,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
|
|||
Return(fmt.Errorf("no notif"))
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
|
||||
|
@ -102,7 +102,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
|
|||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message)
|
||||
|
@ -120,7 +120,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
|
|||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message)
|
||||
|
@ -142,7 +142,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
|
|||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math"
|
||||
"math/big"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TimingAttackDelayFunc describes a function for preventing timing attacks via a delay.
|
||||
type TimingAttackDelayFunc func(logger *logrus.Entry, requestTime time.Time, successful *bool)
|
||||
|
||||
// TimingAttackDelay creates a new standard timing delay func.
|
||||
func TimingAttackDelay(history int, minDelayMs float64, maxRandomMs int64, initialDelay time.Duration) TimingAttackDelayFunc {
|
||||
var (
|
||||
mutex = &sync.Mutex{}
|
||||
cursor = 0
|
||||
)
|
||||
|
||||
execDurationMovingAverage := make([]time.Duration, history)
|
||||
|
||||
for i := range execDurationMovingAverage {
|
||||
execDurationMovingAverage[i] = initialDelay
|
||||
}
|
||||
|
||||
return func(logger *logrus.Entry, requestTime time.Time, successful *bool) {
|
||||
successfulValue := false
|
||||
if successful != nil {
|
||||
successfulValue = *successful
|
||||
}
|
||||
|
||||
execDuration := time.Since(requestTime)
|
||||
execDurationAvgMs := movingAverageIteration(execDuration, history, successfulValue, &cursor, &execDurationMovingAverage, mutex)
|
||||
actualDelayMs := calculateActualDelay(logger, execDuration, execDurationAvgMs, minDelayMs, maxRandomMs, successfulValue)
|
||||
time.Sleep(time.Duration(actualDelayMs) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func movingAverageIteration(value time.Duration, history int, successful bool, cursor *int, movingAvg *[]time.Duration, mutex sync.Locker) float64 {
|
||||
mutex.Lock()
|
||||
|
||||
var sum int64
|
||||
|
||||
for _, v := range *movingAvg {
|
||||
sum += v.Milliseconds()
|
||||
}
|
||||
|
||||
if successful {
|
||||
(*movingAvg)[*cursor] = value
|
||||
*cursor = (*cursor + 1) % history
|
||||
}
|
||||
|
||||
mutex.Unlock()
|
||||
|
||||
return float64(sum / int64(history))
|
||||
}
|
||||
|
||||
func calculateActualDelay(logger *logrus.Entry, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
|
||||
randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs))
|
||||
if err != nil {
|
||||
return float64(maxRandomMs)
|
||||
}
|
||||
|
||||
totalDelayMs := math.Max(execDurationAvgMs, minDelayMs) + float64(randomDelayMs.Int64())
|
||||
actualDelayMs = math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
|
||||
logger.Tracef("Timing Attack Delay successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", successful, execDuration.Milliseconds(), int64(execDurationAvgMs), randomDelayMs.Int64(), int64(totalDelayMs), int64(actualDelayMs))
|
||||
|
||||
return actualDelayMs
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
func TestTimingAttackDelayAverages(t *testing.T) {
|
||||
execDuration := time.Millisecond * 500
|
||||
oneSecond := time.Millisecond * 1000
|
||||
durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond}
|
||||
cursor := 0
|
||||
mutex := &sync.Mutex{}
|
||||
avgExecDuration := movingAverageIteration(execDuration, 10, false, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, float64(1000))
|
||||
|
||||
execDurations := []time.Duration{
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
}
|
||||
|
||||
current := float64(1000)
|
||||
|
||||
// Execute at 500ms for 12 requests.
|
||||
for _, execDuration = range execDurations {
|
||||
avgExecDuration = movingAverageIteration(execDuration, 10, true, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, current)
|
||||
|
||||
// Should not dip below 500, and should decrease in value by 50 each iteration.
|
||||
if current > 500 {
|
||||
current -= 50
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingAttackDelayCalculations(t *testing.T) {
|
||||
execDuration := 500 * time.Millisecond
|
||||
avgExecDurationMs := 1000.0
|
||||
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
|
||||
|
||||
logger := logging.Logger().WithFields(logrus.Fields{})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(85))
|
||||
}
|
||||
|
||||
execDuration = 5 * time.Millisecond
|
||||
avgExecDurationMs = 5.0
|
||||
expectedMinimumDelayMs = 250 - float64(execDuration.Milliseconds())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(250))
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
duoapi "github.com/duosecurity/duo_api_golang"
|
||||
"github.com/fasthttp/router"
|
||||
|
@ -69,7 +70,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
|||
|
||||
r.POST("/api/checks/safe-redirection", autheliaMiddleware(handlers.CheckSafeRedirection))
|
||||
|
||||
r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(1000, true)))
|
||||
r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(middlewares.TimingAttackDelay(10, 250, 85, time.Second))))
|
||||
r.POST("/api/logout", autheliaMiddleware(handlers.LogoutPost))
|
||||
|
||||
// Only register endpoints if forgot password is not disabled.
|
||||
|
|
Loading…
Reference in New Issue