diff --git a/internal/configuration/const.go b/internal/configuration/const.go index e54bcfcbf..097da815c 100644 --- a/internal/configuration/const.go +++ b/internal/configuration/const.go @@ -18,13 +18,21 @@ const ( constWindows = "windows" ) +var ( + errNoValidator = errors.New("no validator provided") + errNoSources = errors.New("no sources provided") + + errDecodeNonPtrMustHaveValue = errors.New("must have a non-empty value") +) + const ( errFmtSecretAlreadyDefined = "secrets: error loading secret into key '%s': it's already defined in other " + "configuration sources" errFmtSecretIOIssue = "secrets: error loading secret path %s into key '%s': %v" errFmtGenerateConfiguration = "error occurred generating configuration: %+v" + + errFmtDecodeHookCouldNotParse = "could not decode '%s' to a %s: %w" + errFmtDecodeHookCouldNotParseEmptyValue = "could not decode an empty value to a %s: %w" ) var secretSuffixes = []string{"key", "secret", "password", "token"} -var errNoSources = errors.New("no sources provided") -var errNoValidator = errors.New("no validator provided") diff --git a/internal/configuration/decode_hooks.go b/internal/configuration/decode_hooks.go index 4258513a4..e7fa14988 100644 --- a/internal/configuration/decode_hooks.go +++ b/internal/configuration/decode_hooks.go @@ -13,32 +13,53 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) -// StringToMailAddressHookFunc decodes a string into a mail.Address. +// StringToMailAddressHookFunc decodes a string into a mail.Address or *mail.Address. func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { - if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) { + var ptr bool + + if f.Kind() != reflect.String { + return data, nil + } + + kindStr := "mail.Address (RFC5322)" + + if t.Kind() == reflect.Ptr { + ptr = true + kindStr = "*" + kindStr + } + + expectedType := reflect.TypeOf(mail.Address{}) + + if ptr && t.Elem() != expectedType { + return data, nil + } else if !ptr && t != expectedType { return data, nil } dataStr := data.(string) - if dataStr == "" { + var result *mail.Address + + if dataStr != "" { + if result, err = mail.ParseAddress(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) + } + } + + if ptr { + return result, nil + } + + if result == nil { return mail.Address{}, nil } - var ( - parsedAddress *mail.Address - ) - - if parsedAddress, err = mail.ParseAddress(dataStr); err != nil { - return nil, fmt.Errorf("could not parse '%s' as a RFC5322 address: %w", dataStr, err) - } - - return *parsedAddress, nil + return *result, nil } } -// StringToURLHookFunc converts string types into a url.URL. +// StringToURLHookFunc converts string types into a url.URL or *url.URL. func StringToURLHookFunc() mapstructure.DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { var ptr bool @@ -47,37 +68,40 @@ func StringToURLHookFunc() mapstructure.DecodeHookFuncType { return data, nil } - ptr = t.Kind() == reflect.Ptr + kindStr := "url.URL" - typeURL := reflect.TypeOf(url.URL{}) + if t.Kind() == reflect.Ptr { + ptr = true + kindStr = "*" + kindStr + } - if ptr && t.Elem() != typeURL { + expectedType := reflect.TypeOf(url.URL{}) + + if ptr && t.Elem() != expectedType { return data, nil - } else if !ptr && t != typeURL { + } else if !ptr && t != expectedType { return data, nil } dataStr := data.(string) - var parsedURL *url.URL + var result *url.URL - // Return an empty URL if there is an empty string. if dataStr != "" { - if parsedURL, err = url.Parse(dataStr); err != nil { - return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err) + if result, err = url.Parse(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) } } if ptr { - return parsedURL, nil + return result, nil } - // Return an empty URL if there is an empty string. - if parsedURL == nil { + if result == nil { return url.URL{}, nil } - return *parsedURL, nil + return *result, nil } } @@ -94,48 +118,51 @@ func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType { return data, nil } - typeTimeDuration := reflect.TypeOf(time.Hour) + kindStr := "time.Duration" if t.Kind() == reflect.Ptr { - if t.Elem() != typeTimeDuration { - return data, nil - } - ptr = true - } else if t != typeTimeDuration { + kindStr = "*" + kindStr + } + + expectedType := reflect.TypeOf(time.Duration(0)) + + if ptr && t.Elem() != expectedType { + return data, nil + } else if !ptr && t != expectedType { return data, nil } - var duration time.Duration + var result time.Duration switch { case f.Kind() == reflect.String: dataStr := data.(string) - if duration, err = utils.ParseDurationString(dataStr); err != nil { - return nil, err + if result, err = utils.ParseDurationString(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) } case f.Kind() == reflect.Int: seconds := data.(int) - duration = time.Second * time.Duration(seconds) + result = time.Second * time.Duration(seconds) case f.Kind() == reflect.Int32: seconds := data.(int32) - duration = time.Second * time.Duration(seconds) - case f == typeTimeDuration: - duration = data.(time.Duration) + result = time.Second * time.Duration(seconds) + case f == expectedType: + result = data.(time.Duration) case f.Kind() == reflect.Int64: seconds := data.(int64) - duration = time.Second * time.Duration(seconds) + result = time.Second * time.Duration(seconds) } if ptr { - return &duration, nil + return &result, nil } - return duration, nil + return result, nil } } @@ -148,27 +175,39 @@ func StringToRegexpFunc() mapstructure.DecodeHookFuncType { return data, nil } - ptr = t.Kind() == reflect.Ptr + kindStr := "regexp.Regexp" - typeRegexp := reflect.TypeOf(regexp.Regexp{}) + if t.Kind() == reflect.Ptr { + ptr = true + kindStr = "*" + kindStr + } - if ptr && t.Elem() != typeRegexp { + expectedType := reflect.TypeOf(regexp.Regexp{}) + + if ptr && t.Elem() != expectedType { return data, nil - } else if !ptr && t != typeRegexp { + } else if !ptr && t != expectedType { return data, nil } - regexStr := data.(string) + dataStr := data.(string) - pattern, err := regexp.Compile(regexStr) - if err != nil { - return nil, fmt.Errorf("could not parse '%s' as regexp: %w", regexStr, err) + var result *regexp.Regexp + + if dataStr != "" { + if result, err = regexp.Compile(dataStr); err != nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParse, dataStr, kindStr, err) + } } if ptr { - return pattern, nil + return result, nil } - return *pattern, nil + if result == nil { + return nil, fmt.Errorf(errFmtDecodeHookCouldNotParseEmptyValue, kindStr, errDecodeNonPtrMustHaveValue) + } + + return *result, nil } } diff --git a/internal/configuration/decode_hooks_test.go b/internal/configuration/decode_hooks_test.go index 8adeca21d..c34011221 100644 --- a/internal/configuration/decode_hooks_test.go +++ b/internal/configuration/decode_hooks_test.go @@ -1,6 +1,7 @@ -package configuration +package configuration_test import ( + "net/mail" "net/url" "reflect" "regexp" @@ -9,531 +10,754 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/authelia/authelia/v4/internal/configuration" ) -func TestStringToURLHookFunc_ShouldNotParseStrings(t *testing.T) { - hook := StringToURLHookFunc() - - var ( - from = "https://google.com/abc?a=123" - - result interface{} - err error - - resultTo string - resultPtrTo *time.Time - ok bool - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) - assert.NoError(t, err) - - resultTo, ok = result.(string) - assert.True(t, ok) - assert.Equal(t, from, resultTo) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) - assert.NoError(t, err) - - resultTo, ok = result.(string) - assert.True(t, ok) - assert.Equal(t, from, resultTo) -} - -func TestStringToURLHookFunc_ShouldParseEmptyString(t *testing.T) { - hook := StringToURLHookFunc() - - var ( - from = "" - - result interface{} - err error - - resultTo url.URL - resultPtrTo *url.URL - - ok bool - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) - assert.NoError(t, err) - - resultTo, ok = result.(url.URL) - assert.True(t, ok) - assert.Equal(t, "", resultTo.String()) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) - assert.NoError(t, err) - - resultPtrTo, ok = result.(*url.URL) - assert.True(t, ok) - assert.Nil(t, resultPtrTo) -} - -func TestStringToURLHookFunc_ShouldNotParseBadURLs(t *testing.T) { - hook := StringToURLHookFunc() - - var ( - from = "*(!&@#(!*^$%" - - result interface{} - err error - - resultTo url.URL - resultPtrTo *url.URL - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) - assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"") - assert.Nil(t, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) - assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"") - assert.Nil(t, result) -} - -func TestStringToURLHookFunc_ShouldParseURLs(t *testing.T) { - hook := StringToURLHookFunc() - - var ( - from = "https://google.com/abc?a=123" - - result interface{} - err error - - resultTo url.URL - resultPtrTo *url.URL - - ok bool - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) - assert.NoError(t, err) - - resultTo, ok = result.(url.URL) - assert.True(t, ok) - assert.Equal(t, "https", resultTo.Scheme) - assert.Equal(t, "google.com", resultTo.Host) - assert.Equal(t, "/abc", resultTo.Path) - assert.Equal(t, "a=123", resultTo.RawQuery) - - resultPtrTo, ok = result.(*url.URL) - assert.False(t, ok) - assert.Nil(t, resultPtrTo) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) - assert.NoError(t, err) - - resultPtrTo, ok = result.(*url.URL) - assert.True(t, ok) - assert.NotNil(t, resultPtrTo) - - assert.Equal(t, "https", resultTo.Scheme) - assert.Equal(t, "google.com", resultTo.Host) - assert.Equal(t, "/abc", resultTo.Path) - assert.Equal(t, "a=123", resultTo.RawQuery) - - resultTo, ok = result.(url.URL) - assert.False(t, ok) -} - -func TestToTimeDurationHookFunc_ShouldParse_String(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = "1h" - expected = time.Hour - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_String_Years(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = "1y" - expected = time.Hour * 24 * 365 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_String_Months(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = "1M" - expected = time.Hour * 24 * 30 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_String_Weeks(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = "1w" - expected = time.Hour * 24 * 7 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_String_Days(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = "1d" - expected = time.Hour * 24 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = "abc" - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.EqualError(t, err, "could not parse 'abc' as a duration") - assert.Nil(t, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.EqualError(t, err, "could not parse 'abc' as a duration") - assert.Nil(t, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_Int(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = 60 - expected = time.Second * 60 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_Int32(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = int32(120) - expected = time.Second * 120 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_Int64(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = int64(30) - expected = time.Second * 30 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_Duration(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = time.Second * 30 - expected = time.Second * 30 - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestToTimeDurationHookFunc_ShouldNotParse_Int64ToString(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = int64(30) - - to string - ptrTo *string - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, from, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, from, result) -} - -func TestToTimeDurationHookFunc_ShouldNotParse_FromBool(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = true - - to string - ptrTo *string - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, from, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, from, result) -} - -func TestToTimeDurationHookFunc_ShouldParse_FromZero(t *testing.T) { - hook := ToTimeDurationHookFunc() - - var ( - from = 0 - expected = time.Duration(0) - - to time.Duration - ptrTo *time.Duration - result interface{} - err error - ) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) - assert.NoError(t, err) - assert.Equal(t, &expected, result) -} - -func TestStringToRegexpFunc(t *testing.T) { - wantRegexp := func(regexpStr string) regexp.Regexp { - pattern := regexp.MustCompile(regexpStr) - - return *pattern - } - +func TestStringToMailAddressHookFunc(t *testing.T) { testCases := []struct { - desc string - have interface{} - want regexp.Regexp - wantPtr *regexp.Regexp - wantErr string - wantGroupNames []string + desc string + have interface{} + want interface{} + err string + decode bool }{ { - desc: "should not parse regexp with open paren", - have: "hello(test one two", - wantErr: "could not parse 'hello(test one two' as regexp: error parsing regexp: missing closing ): `hello(test one two`", + desc: "ShouldDecodeMailAddress", + have: "james@example.com", + want: mail.Address{Name: "", Address: "james@example.com"}, + decode: true, }, { - desc: "should parse valid regex", - have: "^(api|admin)$", - want: wantRegexp("^(api|admin)$"), - wantPtr: regexp.MustCompile("^(api|admin)$"), + desc: "ShouldDecodeMailAddressWithName", + have: "James ", + want: mail.Address{Name: "James", Address: "james@example.com"}, + decode: true, }, { - desc: "should parse valid regex with named groups", - have: "^(?Papi|admin)$", - want: wantRegexp("^(?Papi|admin)$"), - wantPtr: regexp.MustCompile("^(?Papi|admin)$"), - wantGroupNames: []string{"area"}, + desc: "ShouldDecodeMailAddressWithEmptyString", + have: "", + want: mail.Address{}, + decode: true, + }, + { + desc: "ShouldNotDecodeInvalidMailAddress", + have: "fred", + want: mail.Address{}, + err: "could not decode 'fred' to a mail.Address (RFC5322): mail: missing '@' or angle-addr", + decode: true, }, } - hook := StringToRegexpFunc() + hook := configuration.StringToMailAddressHookFunc() for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - t.Run("non-ptr", func(t *testing.T) { - result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) - if tc.wantErr != "" { - assert.EqualError(t, err, tc.wantErr) - assert.Nil(t, result) + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} + +func TestStringToMailAddressHookFuncPointer(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + }{ + { + desc: "ShouldDecodeMailAddress", + have: "james@example.com", + want: &mail.Address{Name: "", Address: "james@example.com"}, + decode: true, + }, + { + desc: "ShouldDecodeMailAddressWithName", + have: "James ", + want: &mail.Address{Name: "James", Address: "james@example.com"}, + decode: true, + }, + { + desc: "ShouldDecodeMailAddressWithEmptyString", + have: "", + want: (*mail.Address)(nil), + decode: true, + }, + { + desc: "ShouldNotDecodeInvalidMailAddress", + have: "fred", + want: &mail.Address{}, + err: "could not decode 'fred' to a *mail.Address (RFC5322): mail: missing '@' or angle-addr", + decode: true, + }, + { + desc: "ShouldNotDecodeToInt", + have: "fred", + want: testInt32Ptr(4), + decode: false, + }, + } + + hook := configuration.StringToMailAddressHookFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} + +func TestStringToURLHookFunc(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + }{ + { + desc: "ShouldDecodeURL", + have: "https://www.example.com:9090/abc?test=true", + want: url.URL{Scheme: "https", Host: "www.example.com:9090", Path: "/abc", RawQuery: "test=true"}, + decode: true, + }, + { + desc: "ShouldDecodeURLEmptyString", + have: "", + want: url.URL{}, + decode: true, + }, + { + desc: "ShouldNotDecodeToString", + have: "abc", + want: "", + decode: false, + }, + { + desc: "ShouldDecodeURLWithUserAndPassword", + have: "https://john:abc123@www.example.com:9090/abc?test=true", + want: url.URL{Scheme: "https", Host: "www.example.com:9090", Path: "/abc", RawQuery: "test=true", User: url.UserPassword("john", "abc123")}, + decode: true, + }, + { + desc: "ShouldNotDecodeInt", + have: 5, + want: url.URL{}, + decode: false, + }, + { + desc: "ShouldNotDecodeBool", + have: true, + want: url.URL{}, + decode: false, + }, + { + desc: "ShouldNotDecodeBadURL", + have: "*(!&@#(!*^$%", + want: url.URL{}, + err: "could not decode '*(!&@#(!*^$%' to a url.URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"", + decode: true, + }, + } + + hook := configuration.StringToURLHookFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} + +func TestStringToURLHookFuncPointer(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + }{ + { + desc: "ShouldDecodeURL", + have: "https://www.example.com:9090/abc?test=true", + want: &url.URL{Scheme: "https", Host: "www.example.com:9090", Path: "/abc", RawQuery: "test=true"}, + decode: true, + }, + { + desc: "ShouldDecodeURLEmptyString", + have: "", + want: (*url.URL)(nil), + decode: true, + }, + { + desc: "ShouldDecodeURLWithUserAndPassword", + have: "https://john:abc123@www.example.com:9090/abc?test=true", + want: &url.URL{Scheme: "https", Host: "www.example.com:9090", Path: "/abc", RawQuery: "test=true", User: url.UserPassword("john", "abc123")}, + decode: true, + }, + { + desc: "ShouldNotDecodeInt", + have: 5, + want: &url.URL{}, + decode: false, + }, + { + desc: "ShouldNotDecodeBool", + have: true, + want: &url.URL{}, + decode: false, + }, + { + desc: "ShouldNotDecodeBadURL", + have: "*(!&@#(!*^$%", + want: &url.URL{}, + err: "could not decode '*(!&@#(!*^$%' to a *url.URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"", + decode: true, + }, + { + desc: "ShouldNotDecodeToInt", + have: "fred", + want: testInt32Ptr(4), + decode: false, + }, + } + + hook := configuration.StringToURLHookFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} + +func TestToTimeDurationHookFunc(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + }{ + { + desc: "ShouldDecodeFourtyFiveSeconds", + have: "45s", + want: time.Second * 45, + decode: true, + }, + { + desc: "ShouldDecodeOneMinute", + have: "1m", + want: time.Minute, + decode: true, + }, + { + desc: "ShouldDecodeTwoHours", + have: "2h", + want: time.Hour * 2, + decode: true, + }, + { + desc: "ShouldDecodeThreeDays", + have: "3d", + want: time.Hour * 24 * 3, + decode: true, + }, + { + desc: "ShouldDecodeFourWeeks", + have: "4w", + want: time.Hour * 24 * 7 * 4, + decode: true, + }, + { + desc: "ShouldDecodeFiveMonths", + have: "5M", + want: time.Hour * 24 * 30 * 5, + decode: true, + }, + { + desc: "ShouldDecodeSixYears", + have: "6y", + want: time.Hour * 24 * 365 * 6, + decode: true, + }, + { + desc: "ShouldNotDecodeInvalidString", + have: "abc", + want: time.Duration(0), + err: "could not decode 'abc' to a time.Duration: could not parse 'abc' as a duration", + decode: true, + }, + { + desc: "ShouldDecodeIntToSeconds", + have: 60, + want: time.Second * 60, + decode: true, + }, + { + desc: "ShouldDecodeInt32ToSeconds", + have: int32(90), + want: time.Second * 90, + decode: true, + }, + { + desc: "ShouldDecodeInt64ToSeconds", + have: int64(120), + want: time.Second * 120, + decode: true, + }, + { + desc: "ShouldDecodeTimeDuration", + have: time.Second * 30, + want: time.Second * 30, + decode: true, + }, + { + desc: "ShouldNotDecodeToString", + have: int64(30), + want: "", + decode: false, + }, + { + desc: "ShouldDecodeFromIntZero", + have: 0, + want: time.Duration(0), + decode: true, + }, + { + desc: "ShouldNotDecodeFromBool", + have: true, + want: true, + }, + } + + hook := configuration.ToTimeDurationHookFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} + +func TestToTimeDurationHookFuncPointer(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + }{ + { + desc: "ShouldDecodeFourtyFiveSeconds", + have: "45s", + want: testTimeDurationPtr(time.Second * 45), + decode: true, + }, + { + desc: "ShouldDecodeOneMinute", + have: "1m", + want: testTimeDurationPtr(time.Minute), + decode: true, + }, + { + desc: "ShouldDecodeTwoHours", + have: "2h", + want: testTimeDurationPtr(time.Hour * 2), + decode: true, + }, + { + desc: "ShouldDecodeThreeDays", + have: "3d", + want: testTimeDurationPtr(time.Hour * 24 * 3), + decode: true, + }, + { + desc: "ShouldDecodeFourWeeks", + have: "4w", + want: testTimeDurationPtr(time.Hour * 24 * 7 * 4), + decode: true, + }, + { + desc: "ShouldDecodeFiveMonths", + have: "5M", + want: testTimeDurationPtr(time.Hour * 24 * 30 * 5), + decode: true, + }, + { + desc: "ShouldDecodeSixYears", + have: "6y", + want: testTimeDurationPtr(time.Hour * 24 * 365 * 6), + decode: true, + }, + { + desc: "ShouldNotDecodeInvalidString", + have: "abc", + want: testTimeDurationPtr(time.Duration(0)), + err: "could not decode 'abc' to a *time.Duration: could not parse 'abc' as a duration", + decode: true, + }, + { + desc: "ShouldDecodeIntToSeconds", + have: 60, + want: testTimeDurationPtr(time.Second * 60), + decode: true, + }, + { + desc: "ShouldDecodeInt32ToSeconds", + have: int32(90), + want: testTimeDurationPtr(time.Second * 90), + decode: true, + }, + { + desc: "ShouldDecodeInt64ToSeconds", + have: int64(120), + want: testTimeDurationPtr(time.Second * 120), + decode: true, + }, + { + desc: "ShouldDecodeTimeDuration", + have: time.Second * 30, + want: testTimeDurationPtr(time.Second * 30), + decode: true, + }, + { + desc: "ShouldNotDecodeToString", + have: int64(30), + want: &testString, + decode: false, + }, + { + desc: "ShouldDecodeFromIntZero", + have: 0, + want: testTimeDurationPtr(time.Duration(0)), + decode: true, + }, + { + desc: "ShouldNotDecodeFromBool", + have: true, + want: &testTrue, + decode: false, + }, + } + + hook := configuration.ToTimeDurationHookFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} + +func TestStringToRegexpFunc(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + wantGrps []string + }{ + { + desc: "ShouldNotDecodeRegexpWithOpenParenthesis", + have: "hello(test one two", + want: regexp.Regexp{}, + err: "could not decode 'hello(test one two' to a regexp.Regexp: error parsing regexp: missing closing ): `hello(test one two`", + decode: true, + }, + { + desc: "ShouldDecodeValidRegex", + have: "^(api|admin)$", + want: *regexp.MustCompile(`^(api|admin)$`), + decode: true, + }, + { + desc: "ShouldDecodeValidRegexWithGroupNames", + have: "^(?Papi|admin)(one|two)$", + want: *regexp.MustCompile(`^(?Papi|admin)(one|two)$`), + decode: true, + wantGrps: []string{"area"}, + }, + { + desc: "ShouldNotDecodeFromInt32", + have: int32(20), + want: regexp.Regexp{}, + decode: false, + }, + { + desc: "ShouldNotDecodeFromBool", + have: false, + want: regexp.Regexp{}, + decode: false, + }, + { + desc: "ShouldNotDecodeToBool", + have: "^(?Papi|admin)(one|two)$", + want: testTrue, + decode: false, + }, + { + desc: "ShouldNotDecodeToInt32", + have: "^(?Papi|admin)(one|two)$", + want: testInt32Ptr(0), + decode: false, + }, + { + desc: "ShouldNotDecodeToMailAddress", + have: "^(?Papi|admin)(one|two)$", + want: mail.Address{}, + decode: false, + }, + { + desc: "ShouldErrOnDecodeEmptyString", + have: "", + want: regexp.Regexp{}, + err: "could not decode an empty value to a regexp.Regexp: must have a non-empty value", + decode: true, + }, + } + + hook := configuration.StringToRegexpFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + + pattern := result.(regexp.Regexp) + + var names []string + for _, name := range pattern.SubexpNames() { + if name != "" { + names = append(names, name) + } + } + + if len(tc.wantGrps) != 0 { + t.Run("MustHaveAllExpectedSubexpGroupNames", func(t *testing.T) { + for _, name := range tc.wantGrps { + assert.Contains(t, names, name) + } + }) + t.Run("MustNotHaveUnexpectedSubexpGroupNames", func(t *testing.T) { + for _, name := range names { + assert.Contains(t, tc.wantGrps, name) + } + }) } else { - assert.NoError(t, err) - require.Equal(t, tc.want, result) + t.Run("MustHaveNoSubexpGroupNames", func(t *testing.T) { + assert.Len(t, names, 0) + }) + } + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } + }) + } +} - resultRegexp := result.(regexp.Regexp) +func TestStringToRegexpFuncPointers(t *testing.T) { + testCases := []struct { + desc string + have interface{} + want interface{} + err string + decode bool + wantGrps []string + }{ + { + desc: "ShouldNotDecodeRegexpWithOpenParenthesis", + have: "hello(test one two", + want: ®exp.Regexp{}, + err: "could not decode 'hello(test one two' to a *regexp.Regexp: error parsing regexp: missing closing ): `hello(test one two`", + decode: true, + }, + { + desc: "ShouldDecodeValidRegex", + have: "^(api|admin)$", + want: regexp.MustCompile(`^(api|admin)$`), + decode: true, + }, + { + desc: "ShouldDecodeValidRegexWithGroupNames", + have: "^(?Papi|admin)(one|two)$", + want: regexp.MustCompile(`^(?Papi|admin)(one|two)$`), + decode: true, + wantGrps: []string{"area"}, + }, + { + desc: "ShouldNotDecodeFromInt32", + have: int32(20), + want: ®exp.Regexp{}, + decode: false, + }, + { + desc: "ShouldNotDecodeFromBool", + have: false, + want: ®exp.Regexp{}, + decode: false, + }, + { + desc: "ShouldNotDecodeToBool", + have: "^(?Papi|admin)(one|two)$", + want: &testTrue, + decode: false, + }, + { + desc: "ShouldNotDecodeToInt32", + have: "^(?Papi|admin)(one|two)$", + want: &testZero, + decode: false, + }, + { + desc: "ShouldNotDecodeToMailAddress", + have: "^(?Papi|admin)(one|two)$", + want: &mail.Address{}, + decode: false, + }, + { + desc: "ShouldDecodeEmptyStringToNil", + have: "", + want: (*regexp.Regexp)(nil), + decode: true, + }, + } - actualNames := resultRegexp.SubexpNames() + hook := configuration.StringToRegexpFunc() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.want), tc.have) + switch { + case !tc.decode: + assert.NoError(t, err) + assert.Equal(t, tc.have, result) + case tc.err == "": + assert.NoError(t, err) + require.Equal(t, tc.want, result) + + pattern := result.(*regexp.Regexp) + + if tc.want == (*regexp.Regexp)(nil) { + assert.Nil(t, pattern) + } else { var names []string - - for _, name := range actualNames { + for _, name := range pattern.SubexpNames() { if name != "" { names = append(names, name) } } - if len(tc.wantGroupNames) != 0 { - t.Run("must_have_all_expected_subexp_group_names", func(t *testing.T) { - for _, name := range tc.wantGroupNames { + if len(tc.wantGrps) != 0 { + t.Run("MustHaveAllExpectedSubexpGroupNames", func(t *testing.T) { + for _, name := range tc.wantGrps { assert.Contains(t, names, name) } }) - t.Run("must_not_have_unexpected_subexp_group_names", func(t *testing.T) { + t.Run("MustNotHaveUnexpectedSubexpGroupNames", func(t *testing.T) { for _, name := range names { - assert.Contains(t, tc.wantGroupNames, name) + assert.Contains(t, tc.wantGrps, name) } }) } else { - t.Run("must_have_no_subexp_group_names", func(t *testing.T) { + t.Run("MustHaveNoSubexpGroupNames", func(t *testing.T) { assert.Len(t, names, 0) }) } } - }) - t.Run("ptr", func(t *testing.T) { - result, err := hook(reflect.TypeOf(tc.have), reflect.TypeOf(tc.wantPtr), tc.have) - if tc.wantErr != "" { - assert.EqualError(t, err, tc.wantErr) - assert.Nil(t, result) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.wantPtr, result) - - resultRegexp := result.(*regexp.Regexp) - - actualNames := resultRegexp.SubexpNames() - var names []string - - for _, name := range actualNames { - if name != "" { - names = append(names, name) - } - } - - if len(tc.wantGroupNames) != 0 { - t.Run("must_have_all_expected_names", func(t *testing.T) { - for _, name := range tc.wantGroupNames { - assert.Contains(t, names, name) - } - }) - - t.Run("must_not_have_unexpected_names", func(t *testing.T) { - for _, name := range names { - assert.Contains(t, tc.wantGroupNames, name) - } - }) - } else { - assert.Len(t, names, 0) - } - } - }) + default: + assert.EqualError(t, err, tc.err) + assert.Nil(t, result) + } }) } } + +func testInt32Ptr(i int32) *int32 { + return &i +} + +func testTimeDurationPtr(t time.Duration) *time.Duration { + return &t +} + +var ( + testTrue = true + testZero int32 + testString = "" +) diff --git a/internal/configuration/provider_test.go b/internal/configuration/provider_test.go index 94a678c43..52cfecfce 100644 --- a/internal/configuration/provider_test.go +++ b/internal/configuration/provider_test.go @@ -236,7 +236,7 @@ func TestShouldRaiseErrOnInvalidNotifierSMTPSender(t *testing.T) { require.Len(t, val.Errors(), 1) assert.Len(t, val.Warnings(), 0) - assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'notifier.smtp.sender': could not parse 'admin' as a RFC5322 address: mail: missing '@' or angle-addr") + assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'notifier.smtp.sender': could not decode 'admin' to a mail.Address (RFC5322): mail: missing '@' or angle-addr") } func TestShouldHandleErrInvalidatorWhenSMTPSenderBlank(t *testing.T) { @@ -343,7 +343,7 @@ func TestShouldErrOnParseInvalidRegex(t *testing.T) { require.Len(t, val.Errors(), 1) assert.Len(t, val.Warnings(), 0) - assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'access_control.rules[0].domain_regex[0]': could not parse '^\\K(public|public2).example.com$' as regexp: error parsing regexp: invalid escape sequence: `\\K`") + assert.EqualError(t, val.Errors()[0], "error occurred during unmarshalling configuration: 1 error(s) decoding:\n\n* error decoding 'access_control.rules[0].domain_regex[0]': could not decode '^\\K(public|public2).example.com$' to a regexp.Regexp: error parsing regexp: invalid escape sequence: `\\K`") } func TestShouldNotReadConfigurationOnFSAccessDenied(t *testing.T) {