refactor(configuration): decode_hooks blackbox and better testing (#3097)

pull/3105/head
James Elliott 2022-04-03 22:44:52 +10:00 committed by GitHub
parent bfd5d66ed8
commit 7230db7cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 817 additions and 546 deletions

View File

@ -18,13 +18,21 @@ const (
constWindows = "windows" constWindows = "windows"
) )
var (
errNoValidator = errors.New("no validator provided")
errNoSources = errors.New("no sources provided")
errDecodeNonPtrMustHaveValue = errors.New("must have a non-empty value")
)
const ( const (
errFmtSecretAlreadyDefined = "secrets: error loading secret into key '%s': it's already defined in other " + errFmtSecretAlreadyDefined = "secrets: error loading secret into key '%s': it's already defined in other " +
"configuration sources" "configuration sources"
errFmtSecretIOIssue = "secrets: error loading secret path %s into key '%s': %v" errFmtSecretIOIssue = "secrets: error loading secret path %s into key '%s': %v"
errFmtGenerateConfiguration = "error occurred generating configuration: %+v" errFmtGenerateConfiguration = "error occurred generating configuration: %+v"
errFmtDecodeHookCouldNotParse = "could not decode '%s' to a %s: %w"
errFmtDecodeHookCouldNotParseEmptyValue = "could not decode an empty value to a %s: %w"
) )
var secretSuffixes = []string{"key", "secret", "password", "token"} var secretSuffixes = []string{"key", "secret", "password", "token"}
var errNoSources = errors.New("no sources provided")
var errNoValidator = errors.New("no validator provided")

View File

@ -13,32 +13,53 @@ import (
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// StringToMailAddressHookFunc decodes a string into a mail.Address. // StringToMailAddressHookFunc decodes a string into a mail.Address or *mail.Address.
func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType { func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) { var ptr bool
if f.Kind() != reflect.String {
return data, nil
}
kindStr := "mail.Address (RFC5322)"
if t.Kind() == reflect.Ptr {
ptr = true
kindStr = "*" + kindStr
}
expectedType := reflect.TypeOf(mail.Address{})
if ptr && t.Elem() != expectedType {
return data, nil
} else if !ptr && t != expectedType {
return data, nil return data, nil
} }
dataStr := data.(string) dataStr := data.(string)
if dataStr == "" { var result *mail.Address
if dataStr != "" {
if result, err = mail.ParseAddress(dataStr); err != nil {
return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
}
}
if ptr {
return result, nil
}
if result == nil {
return mail.Address{}, nil return mail.Address{}, nil
} }
var ( return *result, nil
parsedAddress *mail.Address
)
if parsedAddress, err = mail.ParseAddress(dataStr); err != nil {
return nil, fmt.Errorf("could not parse '%s' as a RFC5322 address: %w", dataStr, err)
}
return *parsedAddress, nil
} }
} }
// StringToURLHookFunc converts string types into a url.URL. // StringToURLHookFunc converts string types into a url.URL or *url.URL.
func StringToURLHookFunc() mapstructure.DecodeHookFuncType { func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var ptr bool var ptr bool
@ -47,37 +68,40 @@ func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
return data, nil return data, nil
} }
ptr = t.Kind() == reflect.Ptr kindStr := "url.URL"
typeURL := reflect.TypeOf(url.URL{}) if t.Kind() == reflect.Ptr {
ptr = true
kindStr = "*" + kindStr
}
if ptr && t.Elem() != typeURL { expectedType := reflect.TypeOf(url.URL{})
if ptr && t.Elem() != expectedType {
return data, nil return data, nil
} else if !ptr && t != typeURL { } else if !ptr && t != expectedType {
return data, nil return data, nil
} }
dataStr := data.(string) dataStr := data.(string)
var parsedURL *url.URL var result *url.URL
// Return an empty URL if there is an empty string.
if dataStr != "" { if dataStr != "" {
if parsedURL, err = url.Parse(dataStr); err != nil { if result, err = url.Parse(dataStr); err != nil {
return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err) return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
} }
} }
if ptr { if ptr {
return parsedURL, nil return result, nil
} }
// Return an empty URL if there is an empty string. if result == nil {
if parsedURL == nil {
return url.URL{}, nil return url.URL{}, nil
} }
return *parsedURL, nil return *result, nil
} }
} }
@ -94,48 +118,51 @@ func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType {
return data, nil return data, nil
} }
typeTimeDuration := reflect.TypeOf(time.Hour) kindStr := "time.Duration"
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
if t.Elem() != typeTimeDuration {
return data, nil
}
ptr = true ptr = true
} else if t != typeTimeDuration { kindStr = "*" + kindStr
}
expectedType := reflect.TypeOf(time.Duration(0))
if ptr && t.Elem() != expectedType {
return data, nil
} else if !ptr && t != expectedType {
return data, nil return data, nil
} }
var duration time.Duration var result time.Duration
switch { switch {
case f.Kind() == reflect.String: case f.Kind() == reflect.String:
dataStr := data.(string) dataStr := data.(string)
if duration, err = utils.ParseDurationString(dataStr); err != nil { if result, err = utils.ParseDurationString(dataStr); err != nil {
return nil, err return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
} }
case f.Kind() == reflect.Int: case f.Kind() == reflect.Int:
seconds := data.(int) seconds := data.(int)
duration = time.Second * time.Duration(seconds) result = time.Second * time.Duration(seconds)
case f.Kind() == reflect.Int32: case f.Kind() == reflect.Int32:
seconds := data.(int32) seconds := data.(int32)
duration = time.Second * time.Duration(seconds) result = time.Second * time.Duration(seconds)
case f == typeTimeDuration: case f == expectedType:
duration = data.(time.Duration) result = data.(time.Duration)
case f.Kind() == reflect.Int64: case f.Kind() == reflect.Int64:
seconds := data.(int64) seconds := data.(int64)
duration = time.Second * time.Duration(seconds) result = time.Second * time.Duration(seconds)
} }
if ptr { if ptr {
return &duration, nil return &result, nil
} }
return duration, nil return result, nil
} }
} }
@ -148,27 +175,39 @@ func StringToRegexpFunc() mapstructure.DecodeHookFuncType {
return data, nil return data, nil
} }
ptr = t.Kind() == reflect.Ptr kindStr := "regexp.Regexp"
typeRegexp := reflect.TypeOf(regexp.Regexp{}) if t.Kind() == reflect.Ptr {
ptr = true
kindStr = "*" + kindStr
}
if ptr && t.Elem() != typeRegexp { expectedType := reflect.TypeOf(regexp.Regexp{})
if ptr && t.Elem() != expectedType {
return data, nil return data, nil
} else if !ptr && t != typeRegexp { } else if !ptr && t != expectedType {
return data, nil return data, nil
} }
regexStr := data.(string) dataStr := data.(string)
pattern, err := regexp.Compile(regexStr) var result *regexp.Regexp
if err != nil {
return nil, fmt.Errorf("could not parse '%s' as regexp: %w", regexStr, err) if dataStr != "" {
if result, err = regexp.Compile(dataStr); err != nil {
return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err)
}
} }
if ptr { if ptr {
return pattern, nil return result, nil
} }
return *pattern, nil if result == nil {
return nil, fmt.Errorf(errFmtDecodeHookCouldNotParseEmptyValue, kindStr, errDecodeNonPtrMustHaveValue)
}
return *result, nil
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -236,7 +236,7 @@ func TestShouldRaiseErrOnInvalidNotifierSMTPSender(t *testing.T) {
require.Len(t, val.Errors(), 1) require.Len(t, val.Errors(), 1)
assert.Len(t, val.Warnings(), 0) assert.Len(t, val.Warnings(), 0)
assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'notifier.smtp.sender': could not parse 'admin' as a RFC5322 address: mail: missing '@' or angle-addr") assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'notifier.smtp.sender': could not decode 'admin' to a mail.Address (RFC5322): mail: missing '@' or angle-addr")
} }
func TestShouldHandleErrInvalidatorWhenSMTPSenderBlank(t *testing.T) { func TestShouldHandleErrInvalidatorWhenSMTPSenderBlank(t *testing.T) {
@ -343,7 +343,7 @@ func TestShouldErrOnParseInvalidRegex(t *testing.T) {
require.Len(t, val.Errors(), 1) require.Len(t, val.Errors(), 1)
assert.Len(t, val.Warnings(), 0) assert.Len(t, val.Warnings(), 0)
assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'access_control.rules[0].domain_regex[0]': could not parse '^\\K(public|public2).example.com$' as regexp: error parsing regexp: invalid escape sequence: `\\K`") assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'access_control.rules[0].domain_regex[0]': could not decode '^\\K(public|public2).example.com$' to a regexp.Regexp: error parsing regexp: invalid escape sequence: `\\K`")
} }
func TestShouldNotReadConfigurationOnFSAccessDenied(t *testing.T) { func TestShouldNotReadConfigurationOnFSAccessDenied(t *testing.T) {