refactor(configuration): utilize time duration decode hook (#2938)

This enhances the existing time.Duration parser to allow multiple units, and implements a decode hook which can be used by koanf to decode string/integers into time.Durations as applicable.
pull/2673/head^2
James Elliott 2022-03-02 17:40:26 +11:00 committed by GitHub
parent d867fa1a63
commit 6276883f04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 686 additions and 343 deletions

View File

@ -100,15 +100,31 @@ $ authelia validate-config --config configuration.yml
# Duration Notation Format # Duration Notation Format
We have implemented a string based notation for configuration options that take a duration. This section describes its We have implemented a string/integer based notation for configuration options that take a duration of time. This section
usage. You can use this implementation in: session for expiration, inactivity, and remember_me_duration; and regulation describes the implementation of this. You can use this implementation in various areas of configuration such as:
for ban_time, and find_time. This notation also supports just providing the number of seconds instead.
The notation is comprised of a number which must be positive and not have leading zeros, followed by a letter - session:
denoting the unit of time measurement. The table below describes the units of time and the associated letter. - expiration
- inactivity
- remember_me_duration
- regulation:
- ban_time
- find_time
- ntp:
- max_desync
- webauthn:
- timeout
The way this format works is you can either configure an integer or a string in the specific configuration areas. If you
supply an integer, it is considered a representation of seconds. If you supply a string, it parses the string in blocks
of quantities and units (number followed by a unit letter). For example `5h` indicates a quantity of 5 units of `h`.
While you can use multiple of these blocks in combination, ee suggest keeping it simple and use a single value.
## Duration Notation Format Unit Legend
| Unit | Associated Letter | | Unit | Associated Letter |
|:-----:|:---------------:| |:-------:|:-----------------:|
| Years | y | | Years | y |
| Months | M | | Months | M |
| Weeks | w | | Weeks | w |
@ -117,10 +133,13 @@ denoting the unit of time measurement. The table below describes the units of ti
| Minutes | m | | Minutes | m |
| Seconds | s | | Seconds | s |
Examples: ## Duration Notation Format Examples
* 1 hour and 30 minutes: 90m
* 1 day: 1d | Desired Value | Configuration Examples |
* 10 hours: 10h |:---------------------:|:-------------------------------------:|
| 1 hour and 30 minutes | `90m` or `1h30m` or `5400` or `5400s` |
| 1 day | `1d` or `24h` or `86400` or `86400s` |
| 10 hours | `10h` or `600m` or `9h60m` or `36000` |
# TLS Configuration # TLS Configuration

View File

@ -57,10 +57,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
} }
var ntpProvider *ntp.Provider ntpProvider := ntp.NewProvider(&config.NTP)
if config.NTP != nil {
ntpProvider = ntp.NewProvider(config.NTP)
}
clock := utils.RealClock{} clock := utils.RealClock{}
authorizer := authorization.NewAuthorizer(config) authorizer := authorization.NewAuthorizer(config)

View File

@ -4,8 +4,11 @@ import (
"fmt" "fmt"
"net/mail" "net/mail"
"reflect" "reflect"
"time"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/authelia/authelia/v4/internal/utils"
) )
// StringToMailAddressFunc decodes a string into a mail.Address. // StringToMailAddressFunc decodes a string into a mail.Address.
@ -33,3 +36,67 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
return *mailAddress, nil return *mailAddress, nil
} }
} }
// ToTimeDurationFunc converts string and integer types to a time.Duration.
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var (
ptr bool
)
switch f.Kind() {
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:
// We only allow string and integer from kinds to match.
break
default:
return data, nil
}
typeTimeDuration := reflect.TypeOf(time.Hour)
if t.Kind() == reflect.Ptr {
if t.Elem() != typeTimeDuration {
return data, nil
}
ptr = true
} else if t != typeTimeDuration {
return data, nil
}
var duration time.Duration
switch {
case f.Kind() == reflect.String:
break
case f.Kind() == reflect.Int:
seconds := data.(int)
duration = time.Second * time.Duration(seconds)
case f.Kind() == reflect.Int32:
seconds := data.(int32)
duration = time.Second * time.Duration(seconds)
case f == typeTimeDuration:
duration = data.(time.Duration)
case f.Kind() == reflect.Int64:
seconds := data.(int64)
duration = time.Second * time.Duration(seconds)
}
if duration == 0 {
dataStr := data.(string)
if duration, err = utils.ParseDurationString(dataStr); err != nil {
return nil, err
}
}
if ptr {
return &duration, nil
}
return duration, nil
}
}

View File

@ -0,0 +1,270 @@
package configuration
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1h"
expected = time.Hour
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1y"
expected = time.Hour * 24 * 365
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1M"
expected = time.Hour * 24 * 30
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1w"
expected = time.Hour * 24 * 7
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1d"
expected = time.Hour * 24
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "abc"
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.EqualError(t, err, "could not parse 'abc' as a duration")
assert.Nil(t, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.EqualError(t, err, "could not parse 'abc' as a duration")
assert.Nil(t, result)
}
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = 60
expected = time.Second * 60
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = int32(120)
expected = time.Second * 120
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = int64(30)
expected = time.Second * 30
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = time.Second * 30
expected = time.Second * 30
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = int64(30)
to string
ptrTo *string
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
}
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = true
to string
ptrTo *string
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
}

View File

@ -43,9 +43,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
c := koanf.UnmarshalConf{ c := koanf.UnmarshalConf{
DecoderConfig: &mapstructure.DecoderConfig{ DecoderConfig: &mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc( DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","), mapstructure.StringToSliceHookFunc(","),
StringToMailAddressFunc(), StringToMailAddressFunc(),
ToTimeDurationFunc(),
), ),
Metadata: nil, Metadata: nil,
Result: o, Result: o,

View File

@ -14,8 +14,8 @@ type Configuration struct {
TOTP *TOTPConfiguration `koanf:"totp"` TOTP *TOTPConfiguration `koanf:"totp"`
DuoAPI *DuoAPIConfiguration `koanf:"duo_api"` DuoAPI *DuoAPIConfiguration `koanf:"duo_api"`
AccessControl AccessControlConfiguration `koanf:"access_control"` AccessControl AccessControlConfiguration `koanf:"access_control"`
NTP *NTPConfiguration `koanf:"ntp"` NTP NTPConfiguration `koanf:"ntp"`
Regulation *RegulationConfiguration `koanf:"regulation"` Regulation RegulationConfiguration `koanf:"regulation"`
Storage StorageConfiguration `koanf:"storage"` Storage StorageConfiguration `koanf:"storage"`
Notifier *NotifierConfiguration `koanf:"notifier"` Notifier *NotifierConfiguration `koanf:"notifier"`
Server ServerConfiguration `koanf:"server"` Server ServerConfiguration `koanf:"server"`

View File

@ -1,10 +1,14 @@
package schema package schema
import (
"time"
)
// NTPConfiguration represents the configuration related to ntp server. // NTPConfiguration represents the configuration related to ntp server.
type NTPConfiguration struct { type NTPConfiguration struct {
Address string `koanf:"address"` Address string `koanf:"address"`
Version int `koanf:"version"` Version int `koanf:"version"`
MaximumDesync string `koanf:"max_desync"` MaximumDesync time.Duration `koanf:"max_desync"`
DisableStartupCheck bool `koanf:"disable_startup_check"` DisableStartupCheck bool `koanf:"disable_startup_check"`
DisableFailure bool `koanf:"disable_failure"` DisableFailure bool `koanf:"disable_failure"`
} }
@ -13,5 +17,5 @@ type NTPConfiguration struct {
var DefaultNTPConfiguration = NTPConfiguration{ var DefaultNTPConfiguration = NTPConfiguration{
Address: "time.cloudflare.com:123", Address: "time.cloudflare.com:123",
Version: 4, Version: 4,
MaximumDesync: "3s", MaximumDesync: time.Second * 3,
} }

View File

@ -1,15 +1,19 @@
package schema package schema
import (
"time"
)
// RegulationConfiguration represents the configuration related to regulation. // RegulationConfiguration represents the configuration related to regulation.
type RegulationConfiguration struct { type RegulationConfiguration struct {
MaxRetries int `koanf:"max_retries"` MaxRetries int `koanf:"max_retries"`
FindTime string `koanf:"find_time,weak"` FindTime time.Duration `koanf:"find_time,weak"`
BanTime string `koanf:"ban_time,weak"` BanTime time.Duration `koanf:"ban_time,weak"`
} }
// DefaultRegulationConfiguration represents default configuration parameters for the regulator. // DefaultRegulationConfiguration represents default configuration parameters for the regulator.
var DefaultRegulationConfiguration = RegulationConfiguration{ var DefaultRegulationConfiguration = RegulationConfiguration{
MaxRetries: 3, MaxRetries: 3,
FindTime: "2m", FindTime: time.Minute * 2,
BanTime: "5m", BanTime: time.Minute * 5,
} }

View File

@ -1,5 +1,9 @@
package schema package schema
import (
"time"
)
// RedisNode Represents a Node. // RedisNode Represents a Node.
type RedisNode struct { type RedisNode struct {
Host string `koanf:"host"` Host string `koanf:"host"`
@ -35,17 +39,18 @@ type SessionConfiguration struct {
Domain string `koanf:"domain"` Domain string `koanf:"domain"`
SameSite string `koanf:"same_site"` SameSite string `koanf:"same_site"`
Secret string `koanf:"secret"` Secret string `koanf:"secret"`
Expiration string `koanf:"expiration"` Expiration time.Duration `koanf:"expiration"`
Inactivity string `koanf:"inactivity"` Inactivity time.Duration `koanf:"inactivity"`
RememberMeDuration string `koanf:"remember_me_duration"` RememberMeDuration time.Duration `koanf:"remember_me_duration"`
Redis *RedisSessionConfiguration `koanf:"redis"` Redis *RedisSessionConfiguration `koanf:"redis"`
} }
// DefaultSessionConfiguration is the default session configuration. // DefaultSessionConfiguration is the default session configuration.
var DefaultSessionConfiguration = SessionConfiguration{ var DefaultSessionConfiguration = SessionConfiguration{
Name: "authelia_session", Name: "authelia_session",
Expiration: "1h", Expiration: time.Hour,
Inactivity: "5m", Inactivity: time.Minute * 5,
RememberMeDuration: "1M", RememberMeDuration: time.Hour * 24 * 30,
SameSite: "lax", SameSite: "lax",
} }

View File

@ -35,7 +35,6 @@ const (
// Test constants. // Test constants.
const ( const (
testBadTimer = "-1"
testInvalidPolicy = "invalid" testInvalidPolicy = "invalid"
testJWTSecret = "a_secret" testJWTSecret = "a_secret"
testLDAPBaseDN = "base_dn" testLDAPBaseDN = "base_dn"
@ -186,12 +185,10 @@ const (
// NTP Error constants. // NTP Error constants.
const ( const (
errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'" errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'"
errFmtNTPMaxDesync = "ntp: option 'max_desync' can't be parsed: %w"
) )
// Session error constants. // Session error constants.
const ( const (
errFmtSessionCouldNotParseDuration = "session: option '%s' could not be parsed: %w"
errFmtSessionOptionRequired = "session: option '%s' is required" errFmtSessionOptionRequired = "session: option '%s' is required"
errFmtSessionDomainMustBeRoot = "session: option 'domain' must be the domain you wish to protect not a wildcard domain but it is configured as '%s'" errFmtSessionDomainMustBeRoot = "session: option 'domain' must be the domain you wish to protect not a wildcard domain but it is configured as '%s'"
errFmtSessionSameSite = "session: option 'same_site' must be one of '%s' but is configured as '%s'" errFmtSessionSameSite = "session: option 'same_site' must be one of '%s' but is configured as '%s'"
@ -206,7 +203,6 @@ const (
// Regulation Error Consts. // Regulation Error Consts.
const ( const (
errFmtRegulationParseDuration = "regulation: option '%s' could not be parsed: %w"
errFmtRegulationFindTimeGreaterThanBanTime = "regulation: option 'find_time' must be less than or equal to option 'ban_time'" errFmtRegulationFindTimeGreaterThanBanTime = "regulation: option 'find_time' must be less than or equal to option 'ban_time'"
) )

View File

@ -4,17 +4,10 @@ import (
"fmt" "fmt"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils"
) )
// ValidateNTP validates and update NTP configuration. // ValidateNTP validates and update NTP configuration.
func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator) { func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator) {
if config.NTP == nil {
config.NTP = &schema.DefaultNTPConfiguration
return
}
if config.NTP.Address == "" { if config.NTP.Address == "" {
config.NTP.Address = schema.DefaultNTPConfiguration.Address config.NTP.Address = schema.DefaultNTPConfiguration.Address
} }
@ -25,12 +18,7 @@ func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator
validator.Push(fmt.Errorf(errFmtNTPVersion, config.NTP.Version)) validator.Push(fmt.Errorf(errFmtNTPVersion, config.NTP.Version))
} }
if config.NTP.MaximumDesync == "" { if config.NTP.MaximumDesync == 0 {
config.NTP.MaximumDesync = schema.DefaultNTPConfiguration.MaximumDesync config.NTP.MaximumDesync = schema.DefaultNTPConfiguration.MaximumDesync
} }
_, err := utils.ParseDurationString(config.NTP.MaximumDesync)
if err != nil {
validator.Push(fmt.Errorf(errFmtNTPMaxDesync, err))
}
} }

View File

@ -11,7 +11,7 @@ import (
func newDefaultNTPConfig() schema.Configuration { func newDefaultNTPConfig() schema.Configuration {
return schema.Configuration{ return schema.Configuration{
NTP: &schema.NTPConfiguration{}, NTP: schema.NTPConfiguration{},
} }
} }
@ -55,18 +55,6 @@ func TestShouldSetDefaultNtpDisableStartupCheck(t *testing.T) {
assert.Equal(t, schema.DefaultNTPConfiguration.DisableStartupCheck, config.NTP.DisableStartupCheck) assert.Equal(t, schema.DefaultNTPConfiguration.DisableStartupCheck, config.NTP.DisableStartupCheck)
} }
func TestShouldRaiseErrorOnMaximumDesyncString(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultNTPConfig()
config.NTP.MaximumDesync = "a second"
ValidateNTP(&config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "ntp: option 'max_desync' can't be parsed: could not parse 'a second' as a duration")
}
func TestShouldRaiseErrorOnInvalidNTPVersion(t *testing.T) { func TestShouldRaiseErrorOnInvalidNTPVersion(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultNTPConfig() config := newDefaultNTPConfig()

View File

@ -4,36 +4,19 @@ import (
"fmt" "fmt"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils"
) )
// ValidateRegulation validates and update regulator configuration. // ValidateRegulation validates and update regulator configuration.
func ValidateRegulation(config *schema.Configuration, validator *schema.StructValidator) { func ValidateRegulation(config *schema.Configuration, validator *schema.StructValidator) {
if config.Regulation == nil { if config.Regulation.FindTime == 0 {
config.Regulation = &schema.DefaultRegulationConfiguration
return
}
if config.Regulation.FindTime == "" {
config.Regulation.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min. config.Regulation.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min.
} }
if config.Regulation.BanTime == "" { if config.Regulation.BanTime == 0 {
config.Regulation.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min. config.Regulation.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min.
} }
findTime, err := utils.ParseDurationString(config.Regulation.FindTime) if config.Regulation.FindTime > config.Regulation.BanTime {
if err != nil {
validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "find_time", err))
}
banTime, err := utils.ParseDurationString(config.Regulation.BanTime)
if err != nil {
validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "ban_time", err))
}
if findTime > banTime {
validator.Push(fmt.Errorf(errFmtRegulationFindTimeGreaterThanBanTime)) validator.Push(fmt.Errorf(errFmtRegulationFindTimeGreaterThanBanTime))
} }
} }

View File

@ -2,6 +2,7 @@ package validator
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -10,7 +11,7 @@ import (
func newDefaultRegulationConfig() schema.Configuration { func newDefaultRegulationConfig() schema.Configuration {
config := schema.Configuration{ config := schema.Configuration{
Regulation: &schema.RegulationConfiguration{}, Regulation: schema.RegulationConfiguration{},
} }
return config return config
@ -39,24 +40,11 @@ func TestShouldSetDefaultRegulationFindTime(t *testing.T) {
func TestShouldRaiseErrorWhenFindTimeLessThanBanTime(t *testing.T) { func TestShouldRaiseErrorWhenFindTimeLessThanBanTime(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultRegulationConfig() config := newDefaultRegulationConfig()
config.Regulation.FindTime = "1m" config.Regulation.FindTime = time.Minute
config.Regulation.BanTime = "10s" config.Regulation.BanTime = time.Second * 10
ValidateRegulation(&config, validator) ValidateRegulation(&config, validator)
assert.Len(t, validator.Errors(), 1) assert.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' must be less than or equal to option 'ban_time'") assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' must be less than or equal to option 'ban_time'")
} }
func TestShouldRaiseErrorOnBadDurationStrings(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultRegulationConfig()
config.Regulation.FindTime = "a year"
config.Regulation.BanTime = "forever"
ValidateRegulation(&config, validator)
assert.Len(t, validator.Errors(), 2)
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' could not be parsed: could not parse 'a year' as a duration")
assert.EqualError(t, validator.Errors()[1], "regulation: option 'ban_time' could not be parsed: could not parse 'forever' as a duration")
}

View File

@ -27,22 +27,16 @@ func ValidateSession(config *schema.SessionConfiguration, validator *schema.Stru
} }
func validateSession(config *schema.SessionConfiguration, validator *schema.StructValidator) { func validateSession(config *schema.SessionConfiguration, validator *schema.StructValidator) {
if config.Expiration == "" { if config.Expiration <= 0 {
config.Expiration = schema.DefaultSessionConfiguration.Expiration // 1 hour. config.Expiration = schema.DefaultSessionConfiguration.Expiration // 1 hour.
} else if _, err := utils.ParseDurationString(config.Expiration); err != nil {
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "expiriation", err))
} }
if config.Inactivity == "" { if config.Inactivity <= 0 {
config.Inactivity = schema.DefaultSessionConfiguration.Inactivity // 5 min. config.Inactivity = schema.DefaultSessionConfiguration.Inactivity // 5 min.
} else if _, err := utils.ParseDurationString(config.Inactivity); err != nil {
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "inactivity", err))
} }
if config.RememberMeDuration == "" { if config.RememberMeDuration <= 0 {
config.RememberMeDuration = schema.DefaultSessionConfiguration.RememberMeDuration // 1 month. config.RememberMeDuration = schema.DefaultSessionConfiguration.RememberMeDuration // 1 month.
} else if _, err := utils.ParseDurationString(config.RememberMeDuration); err != nil {
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "remember_me_duration", err))
} }
if config.Domain == "" { if config.Domain == "" {

View File

@ -420,30 +420,21 @@ func TestShouldNotRaiseErrorWhenSameSiteSetCorrectly(t *testing.T) {
} }
} }
func TestShouldRaiseErrorWhenBadInactivityAndExpirationSet(t *testing.T) { func TestShouldSetDefaultWhenNegativeInactivityAndExpirationSet(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultSessionConfig() config := newDefaultSessionConfig()
config.Inactivity = testBadTimer config.Inactivity = -1
config.Expiration = testBadTimer config.Expiration = -1
config.RememberMeDuration = -1
ValidateSession(&config, validator) ValidateSession(&config, validator)
assert.False(t, validator.HasWarnings()) assert.Len(t, validator.Warnings(), 0)
assert.Len(t, validator.Errors(), 2) assert.Len(t, validator.Errors(), 0)
assert.EqualError(t, validator.Errors()[0], "session: option 'expiriation' could not be parsed: could not parse '-1' as a duration")
assert.EqualError(t, validator.Errors()[1], "session: option 'inactivity' could not be parsed: could not parse '-1' as a duration")
}
func TestShouldRaiseErrorWhenBadRememberMeDurationSet(t *testing.T) { assert.Equal(t, schema.DefaultSessionConfiguration.Inactivity, config.Inactivity)
validator := schema.NewStructValidator() assert.Equal(t, schema.DefaultSessionConfiguration.Expiration, config.Expiration)
config := newDefaultSessionConfig() assert.Equal(t, schema.DefaultSessionConfiguration.RememberMeDuration, config.RememberMeDuration)
config.RememberMeDuration = "1 year"
ValidateSession(&config, validator)
assert.False(t, validator.HasWarnings())
assert.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "session: option 'remember_me_duration' could not be parsed: could not parse '1 year' as a duration")
} }
func TestShouldSetDefaultRememberMeDuration(t *testing.T) { func TestShouldSetDefaultRememberMeDuration(t *testing.T) {

View File

@ -1,6 +1,8 @@
package handlers package handlers
import ( import (
"time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -56,7 +58,7 @@ const (
) )
const ( const (
testInactivity = "10" testInactivity = time.Second * 10
testRedirectionURL = "http://redirection.local" testRedirectionURL = "http://redirection.local"
testUsername = "john" testUsername = "john"
) )

View File

@ -602,7 +602,7 @@ func TestShouldDestroySessionWhenInactiveForTooLongUsingDurationNotation(t *test
clock := mocks.TestingClock{} clock := mocks.TestingClock{}
clock.Set(time.Now()) clock.Set(time.Now())
mock.Ctx.Configuration.Session.Inactivity = "10s" mock.Ctx.Configuration.Session.Inactivity = time.Second * 10
// Reload the session provider since the configuration is indirect. // Reload the session provider since the configuration is indirect.
mock.Ctx.Providers.SessionProvider = session.NewProvider(mock.Ctx.Configuration.Session, nil) mock.Ctx.Providers.SessionProvider = session.NewProvider(mock.Ctx.Configuration.Session, nil)
assert.Equal(t, time.Second*10, mock.Ctx.Providers.SessionProvider.Inactivity) assert.Equal(t, time.Second*10, mock.Ctx.Providers.SessionProvider.Inactivity)

View File

@ -8,7 +8,6 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
) )
// NewProvider instantiate a ntp provider given a configuration. // NewProvider instantiate a ntp provider given a configuration.
@ -59,11 +58,9 @@ func (p *Provider) StartupCheck() (err error) {
return nil return nil
} }
maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync)
ntpTime := ntpPacketToTime(resp) ntpTime := ntpPacketToTime(resp)
if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result { if result := ntpIsOffsetTooLarge(p.config.MaximumDesync, now, ntpTime); result {
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server") return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
} }

View File

@ -2,6 +2,7 @@ package ntp
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -11,18 +12,17 @@ import (
func TestShouldCheckNTP(t *testing.T) { func TestShouldCheckNTP(t *testing.T) {
config := &schema.Configuration{ config := &schema.Configuration{
NTP: &schema.NTPConfiguration{ NTP: schema.NTPConfiguration{
Address: "time.cloudflare.com:123", Address: "time.cloudflare.com:123",
Version: 4, Version: 4,
MaximumDesync: "3s", MaximumDesync: time.Second * 3,
DisableStartupCheck: false,
}, },
} }
sv := schema.NewStructValidator() sv := schema.NewStructValidator()
validator.ValidateNTP(config, sv) validator.ValidateNTP(config, sv)
ntp := NewProvider(config.NTP) ntp := NewProvider(&config.NTP)
assert.NoError(t, ntp.StartupCheck()) assert.NoError(t, ntp.StartupCheck())
} }

View File

@ -2,7 +2,6 @@ package regulation
import ( import (
"context" "context"
"fmt"
"net" "net"
"time" "time"
@ -13,33 +12,13 @@ import (
) )
// NewRegulator create a regulator instance. // NewRegulator create a regulator instance.
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator { func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
regulator := &Regulator{storageProvider: provider} return &Regulator{
regulator.clock = clock enabled: config.MaxRetries > 0,
storageProvider: provider,
if configuration != nil { clock: clock,
findTime, err := utils.ParseDurationString(configuration.FindTime) config: config,
if err != nil {
panic(err)
} }
banTime, err := utils.ParseDurationString(configuration.BanTime)
if err != nil {
panic(err)
}
if findTime > banTime {
panic(fmt.Errorf("find_time cannot be greater than ban_time"))
}
// Set regulator enabled only if MaxRetries is not 0.
regulator.enabled = configuration.MaxRetries > 0
regulator.maxRetries = configuration.MaxRetries
regulator.findTime = findTime
regulator.banTime = banTime
}
return regulator
} }
// Mark an authentication attempt. // Mark an authentication attempt.
@ -65,15 +44,15 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
return time.Time{}, nil return time.Time{}, nil
} }
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0) attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
if err != nil { if err != nil {
return time.Time{}, nil return time.Time{}, nil
} }
latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.maxRetries) latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.config.MaxRetries)
for _, attempt := range attempts { for _, attempt := range attempts {
if attempt.Successful || len(latestFailedAttempts) >= r.maxRetries { if attempt.Successful || len(latestFailedAttempts) >= r.config.MaxRetries {
// We stop appending failed attempts once we find the first successful attempts or we reach // We stop appending failed attempts once we find the first successful attempts or we reach
// the configured number of retries, meaning the user is already banned. // the configured number of retries, meaning the user is already banned.
break break
@ -84,17 +63,17 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
// If the number of failed attempts within the ban time is less than the max number of retries // If the number of failed attempts within the ban time is less than the max number of retries
// then the user is not banned. // then the user is not banned.
if len(latestFailedAttempts) < r.maxRetries { if len(latestFailedAttempts) < r.config.MaxRetries {
return time.Time{}, nil return time.Time{}, nil
} }
// Now we compute the time between the latest attempt and the MaxRetry-th one. If it's // Now we compute the time between the latest attempt and the MaxRetry-th one. If it's
// within the FindTime then it means that the user has been banned. // within the FindTime then it means that the user has been banned.
durationBetweenLatestAttempts := latestFailedAttempts[0].Time.Sub( durationBetweenLatestAttempts := latestFailedAttempts[0].Time.Sub(
latestFailedAttempts[r.maxRetries-1].Time) latestFailedAttempts[r.config.MaxRetries-1].Time)
if durationBetweenLatestAttempts < r.findTime { if durationBetweenLatestAttempts < r.config.FindTime {
bannedUntil := latestFailedAttempts[0].Time.Add(r.banTime) bannedUntil := latestFailedAttempts[0].Time.Add(r.config.BanTime)
return bannedUntil, ErrUserIsBanned return bannedUntil, ErrUserIsBanned
} }

View File

@ -21,7 +21,7 @@ type RegulatorSuite struct {
ctx context.Context ctx context.Context
ctrl *gomock.Controller ctrl *gomock.Controller
storageMock *mocks.MockStorage storageMock *mocks.MockStorage
configuration schema.RegulationConfiguration config schema.RegulationConfiguration
clock mocks.TestingClock clock mocks.TestingClock
} }
@ -30,10 +30,10 @@ func (s *RegulatorSuite) SetupTest() {
s.storageMock = mocks.NewMockStorage(s.ctrl) s.storageMock = mocks.NewMockStorage(s.ctrl)
s.ctx = context.Background() s.ctx = context.Background()
s.configuration = schema.RegulationConfiguration{ s.config = schema.RegulationConfiguration{
MaxRetries: 3, MaxRetries: 3,
BanTime: "180", BanTime: time.Second * 180,
FindTime: "30", FindTime: time.Second * 30,
} }
s.clock.Set(time.Now()) s.clock.Set(time.Now())
} }
@ -55,7 +55,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
@ -86,7 +86,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
@ -122,7 +122,7 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err) assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
@ -155,7 +155,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err) assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
@ -179,7 +179,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
@ -211,7 +211,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
@ -247,7 +247,7 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
@ -283,24 +283,24 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
// Check Disabled Functionality. // Check Disabled Functionality.
configuration := schema.RegulationConfiguration{ config := schema.RegulationConfiguration{
MaxRetries: 0, MaxRetries: 0,
FindTime: "180", FindTime: time.Second * 180,
BanTime: "180", BanTime: time.Second * 180,
} }
regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock) regulator := regulation.NewRegulator(config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
// Check Enabled Functionality. // Check Enabled Functionality.
configuration = schema.RegulationConfiguration{ config = schema.RegulationConfiguration{
MaxRetries: 1, MaxRetries: 1,
FindTime: "180", FindTime: time.Second * 180,
BanTime: "180", BanTime: time.Second * 180,
} }
regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock) regulator = regulation.NewRegulator(config, s.storageMock, &s.clock)
_, err = regulator.Regulate(s.ctx, "john") _, err = regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err) assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
} }

View File

@ -1,8 +1,7 @@
package regulation package regulation
import ( import (
"time" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -11,12 +10,8 @@ import (
type Regulator struct { type Regulator struct {
// Is the regulation enabled. // Is the regulation enabled.
enabled bool enabled bool
// The number of failed authentication attempt before banning the user.
maxRetries int config schema.RegulationConfiguration
// If a user does the max number of retries within that duration, she will be banned.
findTime time.Duration
// If a user has been banned, this duration is the timelapse during which the user is banned.
banTime time.Duration
storageProvider storage.RegulatorProvider storageProvider storage.RegulatorProvider

View File

@ -28,7 +28,7 @@ var assets embed.FS
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers) autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != "0") rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != -1)
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword) resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
duoSelfEnrollment := f duoSelfEnrollment := f

View File

@ -1,8 +1,12 @@
package session package session
import (
"time"
)
const ( const (
testDomain = "example.com" testDomain = "example.com"
testExpiration = "40" testExpiration = time.Second * 40
testName = "my_session" testName = "my_session"
testUsername = "john" testUsername = "john"
) )

View File

@ -12,7 +12,6 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
) )
// Provider a session provider. // Provider a session provider.
@ -23,38 +22,29 @@ type Provider struct {
} }
// NewProvider instantiate a session provider given a configuration. // NewProvider instantiate a session provider given a configuration.
func NewProvider(configuration schema.SessionConfiguration, certPool *x509.CertPool) *Provider { func NewProvider(config schema.SessionConfiguration, certPool *x509.CertPool) *Provider {
providerConfig := NewProviderConfig(configuration, certPool) c := NewProviderConfig(config, certPool)
provider := new(Provider) provider := new(Provider)
provider.sessionHolder = fasthttpsession.New(providerConfig.config) provider.sessionHolder = fasthttpsession.New(c.config)
logger := logging.Logger() logger := logging.Logger()
duration, err := utils.ParseDurationString(configuration.RememberMeDuration) provider.Inactivity, provider.RememberMe = config.Inactivity, config.RememberMeDuration
if err != nil {
logger.Fatal(err)
}
provider.RememberMe = duration var (
providerImpl fasthttpsession.Provider
duration, err = utils.ParseDurationString(configuration.Inactivity) err error
if err != nil { )
logger.Fatal(err)
}
provider.Inactivity = duration
var providerImpl fasthttpsession.Provider
switch { switch {
case providerConfig.redisConfig != nil: case c.redisConfig != nil:
providerImpl, err = redis.New(*providerConfig.redisConfig) providerImpl, err = redis.New(*c.redisConfig)
if err != nil { if err != nil {
logger.Fatal(err) logger.Fatal(err)
} }
case providerConfig.redisSentinelConfig != nil: case c.redisSentinelConfig != nil:
providerImpl, err = redis.NewFailoverCluster(*providerConfig.redisSentinelConfig) providerImpl, err = redis.NewFailoverCluster(*c.redisSentinelConfig)
if err != nil { if err != nil {
logger.Fatal(err) logger.Fatal(err)
} }

View File

@ -17,10 +17,10 @@ import (
) )
// NewProviderConfig creates a configuration for creating the session provider. // NewProviderConfig creates a configuration for creating the session provider.
func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig { func NewProviderConfig(config schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig {
config := session.NewDefaultConfig() c := session.NewDefaultConfig()
config.SessionIDGeneratorFunc = func() []byte { c.SessionIDGeneratorFunc = func() []byte {
bytes := make([]byte, 32) bytes := make([]byte, 32)
_, _ = rand.Read(bytes) _, _ = rand.Read(bytes)
@ -33,30 +33,30 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
} }
// Override the cookie name. // Override the cookie name.
config.CookieName = configuration.Name c.CookieName = config.Name
// Set the cookie to the given domain. // Set the cookie to the given domain.
config.Domain = configuration.Domain c.Domain = config.Domain
// Set the cookie SameSite option. // Set the cookie SameSite option.
switch configuration.SameSite { switch config.SameSite {
case "strict": case "strict":
config.CookieSameSite = fasthttp.CookieSameSiteStrictMode c.CookieSameSite = fasthttp.CookieSameSiteStrictMode
case "none": case "none":
config.CookieSameSite = fasthttp.CookieSameSiteNoneMode c.CookieSameSite = fasthttp.CookieSameSiteNoneMode
case "lax": case "lax":
config.CookieSameSite = fasthttp.CookieSameSiteLaxMode c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
default: default:
config.CookieSameSite = fasthttp.CookieSameSiteLaxMode c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
} }
// Only serve the header over HTTPS. // Only serve the header over HTTPS.
config.Secure = true c.Secure = true
// Ignore the error as it will be handled by validator. // Ignore the error as it will be handled by validator.
config.Expiration, _ = utils.ParseDurationString(configuration.Expiration) c.Expiration = config.Expiration
config.IsSecureFunc = func(*fasthttp.RequestCtx) bool { c.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
return true return true
} }
@ -68,23 +68,23 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
// If redis configuration is provided, then use the redis provider. // If redis configuration is provided, then use the redis provider.
switch { switch {
case configuration.Redis != nil: case config.Redis != nil:
serializer := NewEncryptingSerializer(configuration.Secret) serializer := NewEncryptingSerializer(config.Secret)
var tlsConfig *tls.Config var tlsConfig *tls.Config
if configuration.Redis.TLS != nil { if config.Redis.TLS != nil {
tlsConfig = utils.NewTLSConfig(configuration.Redis.TLS, tls.VersionTLS12, certPool) tlsConfig = utils.NewTLSConfig(config.Redis.TLS, tls.VersionTLS12, certPool)
} }
if configuration.Redis.HighAvailability != nil && configuration.Redis.HighAvailability.SentinelName != "" { if config.Redis.HighAvailability != nil && config.Redis.HighAvailability.SentinelName != "" {
addrs := make([]string, 0) addrs := make([]string, 0)
if configuration.Redis.Host != "" { if config.Redis.Host != "" {
addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(configuration.Redis.Host), configuration.Redis.Port)) addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(config.Redis.Host), config.Redis.Port))
} }
for _, node := range configuration.Redis.HighAvailability.Nodes { for _, node := range config.Redis.HighAvailability.Nodes {
addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port) addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port)
if !utils.IsStringInSlice(addr, addrs) { if !utils.IsStringInSlice(addr, addrs) {
addrs = append(addrs, addr) addrs = append(addrs, addr)
@ -94,17 +94,17 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
providerName = "redis-sentinel" providerName = "redis-sentinel"
redisSentinelConfig = &redis.FailoverConfig{ redisSentinelConfig = &redis.FailoverConfig{
Logger: &redisLogger{logger: logging.Logger()}, Logger: &redisLogger{logger: logging.Logger()},
MasterName: configuration.Redis.HighAvailability.SentinelName, MasterName: config.Redis.HighAvailability.SentinelName,
SentinelAddrs: addrs, SentinelAddrs: addrs,
SentinelUsername: configuration.Redis.HighAvailability.SentinelUsername, SentinelUsername: config.Redis.HighAvailability.SentinelUsername,
SentinelPassword: configuration.Redis.HighAvailability.SentinelPassword, SentinelPassword: config.Redis.HighAvailability.SentinelPassword,
RouteByLatency: configuration.Redis.HighAvailability.RouteByLatency, RouteByLatency: config.Redis.HighAvailability.RouteByLatency,
RouteRandomly: configuration.Redis.HighAvailability.RouteRandomly, RouteRandomly: config.Redis.HighAvailability.RouteRandomly,
Username: configuration.Redis.Username, Username: config.Redis.Username,
Password: configuration.Redis.Password, Password: config.Redis.Password,
DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index. DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: configuration.Redis.MaximumActiveConnections, PoolSize: config.Redis.MaximumActiveConnections,
MinIdleConns: configuration.Redis.MinimumIdleConnections, MinIdleConns: config.Redis.MinimumIdleConnections,
IdleTimeout: 300, IdleTimeout: 300,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
KeyPrefix: "authelia-session", KeyPrefix: "authelia-session",
@ -115,36 +115,36 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
var addr string var addr string
if configuration.Redis.Port == 0 { if config.Redis.Port == 0 {
network = "unix" network = "unix"
addr = configuration.Redis.Host addr = config.Redis.Host
} else { } else {
addr = fmt.Sprintf("%s:%d", configuration.Redis.Host, configuration.Redis.Port) addr = fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port)
} }
redisConfig = &redis.Config{ redisConfig = &redis.Config{
Logger: newRedisLogger(), Logger: newRedisLogger(),
Network: network, Network: network,
Addr: addr, Addr: addr,
Username: configuration.Redis.Username, Username: config.Redis.Username,
Password: configuration.Redis.Password, Password: config.Redis.Password,
DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index. DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: configuration.Redis.MaximumActiveConnections, PoolSize: config.Redis.MaximumActiveConnections,
MinIdleConns: configuration.Redis.MinimumIdleConnections, MinIdleConns: config.Redis.MinimumIdleConnections,
IdleTimeout: 300, IdleTimeout: 300,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
KeyPrefix: "authelia-session", KeyPrefix: "authelia-session",
} }
} }
config.EncodeFunc = serializer.Encode c.EncodeFunc = serializer.Encode
config.DecodeFunc = serializer.Decode c.DecodeFunc = serializer.Decode
default: default:
providerName = "memory" providerName = "memory"
} }
return ProviderConfig{ return ProviderConfig{
config, c,
redisConfig, redisConfig,
redisSentinelConfig, redisSentinelConfig,
providerName, providerName,

View File

@ -53,7 +53,25 @@ const (
) )
var ( var (
reDuration = regexp.MustCompile(`^(?P<Duration>[1-9]\d*?)(?P<Unit>[smhdwMy])?$`) standardDurationUnits = []string{"ns", "us", "µs", "μs", "ms", "s", "m", "h"}
reDurationSeconds = regexp.MustCompile(`^\d+$`)
reDurationStandard = regexp.MustCompile(`(?P<Duration>[1-9]\d*?)(?P<Unit>[^\d\s]+)`)
)
// Duration unit types.
const (
DurationUnitDays = "d"
DurationUnitWeeks = "w"
DurationUnitMonths = "M"
DurationUnitYears = "y"
)
// Number of hours in particular measurements of time.
const (
HoursInDay = 24
HoursInWeek = HoursInDay * 7
HoursInMonth = HoursInDay * 30
HoursInYear = HoursInDay * 365
) )
var ( var (

View File

@ -6,46 +6,64 @@ import (
"time" "time"
) )
// ParseDurationString parses a string to a duration // StandardizeDurationString converts units of time that stdlib is unaware of to hours.
// Duration notations are an integer followed by a unit func StandardizeDurationString(input string) (output string, err error) {
// Units are s = second, m = minute, d = day, w = week, M = month, y = year if input == "" {
// Example 1y is the same as 1 year. return "0s", nil
func ParseDurationString(input string) (time.Duration, error) { }
var duration time.Duration
matches := reDuration.FindStringSubmatch(input) matches := reDurationStandard.FindAllStringSubmatch(input, -1)
if len(matches) == 0 {
return "", fmt.Errorf("could not parse '%s' as a duration", input)
}
var d int
for _, match := range matches {
if d, err = strconv.Atoi(match[1]); err != nil {
return "", fmt.Errorf("could not parse the numeric portion of '%s' in duration string '%s': %w", match[0], input, err)
}
unit := match[2]
switch { switch {
case len(matches) == 3 && matches[2] != "": case IsStringInSlice(unit, standardDurationUnits):
d, _ := strconv.Atoi(matches[1]) output += fmt.Sprintf("%d%s", d, unit)
case unit == DurationUnitDays:
switch matches[2] { output += fmt.Sprintf("%dh", d*HoursInDay)
case "y": case unit == DurationUnitWeeks:
duration = time.Duration(d) * Year output += fmt.Sprintf("%dh", d*HoursInWeek)
case "M": case unit == DurationUnitMonths:
duration = time.Duration(d) * Month output += fmt.Sprintf("%dh", d*HoursInMonth)
case "w": case unit == DurationUnitYears:
duration = time.Duration(d) * Week output += fmt.Sprintf("%dh", d*HoursInYear)
case "d": default:
duration = time.Duration(d) * Day return "", fmt.Errorf("could not parse the units portion of '%s' in duration string '%s': the unit '%s' is not valid", match[0], input, unit)
case "h":
duration = time.Duration(d) * Hour
case "m":
duration = time.Duration(d) * time.Minute
case "s":
duration = time.Duration(d) * time.Second
} }
case input == "0" || len(matches) == 3:
seconds, err := strconv.Atoi(input)
if err != nil {
return 0, fmt.Errorf("could not parse '%s' as a duration: %w", input, err)
} }
duration = time.Duration(seconds) * time.Second return output, nil
case input != "":
// Throw this error if input is anything other than a blank string, blank string will default to a duration of nothing.
return 0, fmt.Errorf("could not parse '%s' as a duration", input)
} }
return duration, nil // ParseDurationString standardizes a duration string with StandardizeDurationString then uses time.ParseDuration to
// convert it into a time.Duration.
func ParseDurationString(input string) (duration time.Duration, err error) {
if reDurationSeconds.MatchString(input) {
var seconds int
if seconds, err = strconv.Atoi(input); err != nil {
return 0, nil
}
return time.Second * time.Duration(seconds), nil
}
var out string
if out, err = StandardizeDurationString(input); err != nil {
return 0, err
}
return time.ParseDuration(out)
} }

View File

@ -7,66 +7,112 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestShouldParseDurationString(t *testing.T) { func TestParseDurationString_ShouldParseDurationString(t *testing.T) {
duration, err := ParseDurationString("1h") duration, err := ParseDurationString("1h")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 60*time.Minute, duration) assert.Equal(t, 60*time.Minute, duration)
} }
func TestShouldParseDurationStringAllUnits(t *testing.T) { func TestParseDurationString_ShouldParseBlankString(t *testing.T) {
duration, err := ParseDurationString("1y") duration, err := ParseDurationString("")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, Year, duration) assert.Equal(t, time.Second*0, duration)
}
func TestParseDurationString_ShouldParseDurationStringAllUnits(t *testing.T) {
duration, err := ParseDurationString("1y")
assert.NoError(t, err)
assert.Equal(t, time.Hour*24*365, duration)
duration, err = ParseDurationString("1M") duration, err = ParseDurationString("1M")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, Month, duration) assert.Equal(t, time.Hour*24*30, duration)
duration, err = ParseDurationString("1w") duration, err = ParseDurationString("1w")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, Week, duration) assert.Equal(t, time.Hour*24*7, duration)
duration, err = ParseDurationString("1d") duration, err = ParseDurationString("1d")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, Day, duration) assert.Equal(t, time.Hour*24, duration)
duration, err = ParseDurationString("1h") duration, err = ParseDurationString("1h")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, Hour, duration) assert.Equal(t, time.Hour, duration)
duration, err = ParseDurationString("1s") duration, err = ParseDurationString("1s")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, time.Second, duration) assert.Equal(t, time.Second, duration)
} }
func TestShouldParseSecondsString(t *testing.T) { func TestParseDurationString_ShouldParseSecondsString(t *testing.T) {
duration, err := ParseDurationString("100") duration, err := ParseDurationString("100")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 100*time.Second, duration) assert.Equal(t, 100*time.Second, duration)
} }
func TestShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) { func TestParseDurationString_ShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) {
duration, err := ParseDurationString("h1") duration, err := ParseDurationString("h1")
assert.EqualError(t, err, "could not parse 'h1' as a duration") assert.EqualError(t, err, "could not parse 'h1' as a duration")
assert.Equal(t, time.Duration(0), duration) assert.Equal(t, time.Duration(0), duration)
} }
func TestShouldNotParseBadDurationString(t *testing.T) { func TestParseDurationString_ShouldNotParseBadDurationString(t *testing.T) {
duration, err := ParseDurationString("10x") duration, err := ParseDurationString("10x")
assert.EqualError(t, err, "could not parse '10x' as a duration")
assert.EqualError(t, err, "could not parse the units portion of '10x' in duration string '10x': the unit 'x' is not valid")
assert.Equal(t, time.Duration(0), duration) assert.Equal(t, time.Duration(0), duration)
} }
func TestShouldNotParseDurationStringWithMultiValueUnits(t *testing.T) { func TestParseDurationString_ShouldParseDurationStringWithMultiValueUnits(t *testing.T) {
duration, err := ParseDurationString("10ms") duration, err := ParseDurationString("10ms")
assert.EqualError(t, err, "could not parse '10ms' as a duration")
assert.Equal(t, time.Duration(0), duration) assert.NoError(t, err)
assert.Equal(t, time.Duration(10)*time.Millisecond, duration)
} }
func TestShouldNotParseDurationStringWithLeadingZero(t *testing.T) { func TestParseDurationString_ShouldParseDurationStringWithLeadingZero(t *testing.T) {
duration, err := ParseDurationString("005h") duration, err := ParseDurationString("005h")
assert.EqualError(t, err, "could not parse '005h' as a duration")
assert.Equal(t, time.Duration(0), duration) assert.NoError(t, err)
assert.Equal(t, time.Duration(5)*time.Hour, duration)
}
func TestParseDurationString_ShouldParseMultiUnitValues(t *testing.T) {
duration, err := ParseDurationString("1d3w10ms")
assert.NoError(t, err)
assert.Equal(t,
(time.Hour*time.Duration(24))+
(time.Hour*time.Duration(24)*time.Duration(7)*time.Duration(3))+
(time.Millisecond*time.Duration(10)), duration)
}
func TestParseDurationString_ShouldParseDuplicateUnitValues(t *testing.T) {
duration, err := ParseDurationString("1d4d2d")
assert.NoError(t, err)
assert.Equal(t,
(time.Hour*time.Duration(24))+
(time.Hour*time.Duration(24)*time.Duration(4))+
(time.Hour*time.Duration(24)*time.Duration(2)), duration)
}
func TestStandardizeDurationString_ShouldParseStringWithSpaces(t *testing.T) {
result, err := StandardizeDurationString("1d 1h 20m")
assert.NoError(t, err)
assert.Equal(t, result, "24h1h20m")
} }
func TestShouldTimeIntervalsMakeSense(t *testing.T) { func TestShouldTimeIntervalsMakeSense(t *testing.T) {