diff --git a/internal/configuration/decode_hooks.go b/internal/configuration/decode_hooks.go index 939347639..2bba3f9be 100644 --- a/internal/configuration/decode_hooks.go +++ b/internal/configuration/decode_hooks.go @@ -3,6 +3,7 @@ package configuration import ( "fmt" "net/mail" + "net/url" "reflect" "time" @@ -11,8 +12,8 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) -// StringToMailAddressFunc decodes a string into a mail.Address. -func StringToMailAddressFunc() mapstructure.DecodeHookFunc { +// StringToMailAddressHookFunc decodes a string into a mail.Address. +func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) { return data, nil @@ -36,12 +37,53 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc { } } -// ToTimeDurationFunc converts string and integer types to a time.Duration. -func ToTimeDurationFunc() mapstructure.DecodeHookFuncType { +// StringToURLHookFunc converts string types into a url.URL. +func StringToURLHookFunc() mapstructure.DecodeHookFuncType { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { - var ( - ptr bool - ) + var ptr bool + + if f.Kind() != reflect.String { + return data, nil + } + + ptr = t.Kind() == reflect.Ptr + + typeURL := reflect.TypeOf(url.URL{}) + + if ptr && t.Elem() != typeURL { + return data, nil + } else if !ptr && t != typeURL { + return data, nil + } + + dataStr := data.(string) + + var parsedURL *url.URL + + // Return an empty URL if there is an empty string. + if dataStr != "" { + if parsedURL, err = url.Parse(dataStr); err != nil { + return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err) + } + } + + if ptr { + return parsedURL, nil + } + + // Return an empty URL if there is an empty string. + if parsedURL == nil { + return url.URL{}, nil + } + + return *parsedURL, nil + } +} + +// ToTimeDurationHookFunc converts string and integer types to a time.Duration. +func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType { + return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { + var ptr bool switch f.Kind() { case reflect.String, reflect.Int, reflect.Int32, reflect.Int64: diff --git a/internal/configuration/decode_hooks_test.go b/internal/configuration/decode_hooks_test.go index 41b665e0c..c0ced716d 100644 --- a/internal/configuration/decode_hooks_test.go +++ b/internal/configuration/decode_hooks_test.go @@ -1,6 +1,7 @@ package configuration import ( + "net/url" "reflect" "testing" "time" @@ -8,8 +9,134 @@ import ( "github.com/stretchr/testify/assert" ) -func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) { - hook := ToTimeDurationFunc() +func TestStringToURLHookFunc_ShouldNotParseStrings(t *testing.T) { + hook := StringToURLHookFunc() + + var ( + from = "https://google.com/abc?a=123" + + result interface{} + err error + + resultTo string + resultPtrTo *time.Time + ok bool + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) + assert.NoError(t, err) + + resultTo, ok = result.(string) + assert.True(t, ok) + assert.Equal(t, from, resultTo) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) + assert.NoError(t, err) + + resultTo, ok = result.(string) + assert.True(t, ok) + assert.Equal(t, from, resultTo) +} + +func TestStringToURLHookFunc_ShouldParseEmptyString(t *testing.T) { + hook := StringToURLHookFunc() + + var ( + from = "" + + result interface{} + err error + + resultTo url.URL + resultPtrTo *url.URL + + ok bool + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) + assert.NoError(t, err) + + resultTo, ok = result.(url.URL) + assert.True(t, ok) + assert.Equal(t, "", resultTo.String()) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) + assert.NoError(t, err) + + resultPtrTo, ok = result.(*url.URL) + assert.True(t, ok) + assert.Nil(t, resultPtrTo) +} + +func TestStringToURLHookFunc_ShouldNotParseBadURLs(t *testing.T) { + hook := StringToURLHookFunc() + + var ( + from = "*(!&@#(!*^$%" + + result interface{} + err error + + resultTo url.URL + resultPtrTo *url.URL + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) + assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"") + assert.Nil(t, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) + assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"") + assert.Nil(t, result) +} + +func TestStringToURLHookFunc_ShouldParseURLs(t *testing.T) { + hook := StringToURLHookFunc() + + var ( + from = "https://google.com/abc?a=123" + + result interface{} + err error + + resultTo url.URL + resultPtrTo *url.URL + + ok bool + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from) + assert.NoError(t, err) + + resultTo, ok = result.(url.URL) + assert.True(t, ok) + assert.Equal(t, "https", resultTo.Scheme) + assert.Equal(t, "google.com", resultTo.Host) + assert.Equal(t, "/abc", resultTo.Path) + assert.Equal(t, "a=123", resultTo.RawQuery) + + resultPtrTo, ok = result.(*url.URL) + assert.False(t, ok) + assert.Nil(t, resultPtrTo) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from) + assert.NoError(t, err) + + resultPtrTo, ok = result.(*url.URL) + assert.True(t, ok) + assert.NotNil(t, resultPtrTo) + + assert.Equal(t, "https", resultTo.Scheme) + assert.Equal(t, "google.com", resultTo.Host) + assert.Equal(t, "/abc", resultTo.Path) + assert.Equal(t, "a=123", resultTo.RawQuery) + + resultTo, ok = result.(url.URL) + assert.False(t, ok) +} + +func TestToTimeDurationHookFunc_ShouldParse_String(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = "1h" @@ -30,8 +157,8 @@ func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_String_Years(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = "1y" @@ -52,8 +179,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_String_Months(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = "1M" @@ -74,8 +201,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_String_Weeks(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = "1w" @@ -96,8 +223,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_String_Days(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = "1d" @@ -118,8 +245,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = "abc" @@ -139,8 +266,8 @@ func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T assert.Nil(t, result) } -func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_Int(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = 60 @@ -161,8 +288,8 @@ func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_Int32(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = int32(120) @@ -183,8 +310,8 @@ func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_Int64(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = int64(30) @@ -205,8 +332,8 @@ func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_Duration(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = time.Second * 30 @@ -227,8 +354,8 @@ func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) { assert.Equal(t, &expected, result) } -func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldNotParse_Int64ToString(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = int64(30) @@ -248,8 +375,8 @@ func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) { assert.Equal(t, from, result) } -func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldNotParse_FromBool(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = true @@ -269,8 +396,8 @@ func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) { assert.Equal(t, from, result) } -func TestToTimeDurationFunc_ShouldParse_FromZero(t *testing.T) { - hook := ToTimeDurationFunc() +func TestToTimeDurationHookFunc_ShouldParse_FromZero(t *testing.T) { + hook := ToTimeDurationHookFunc() var ( from = 0 diff --git a/internal/configuration/provider.go b/internal/configuration/provider.go index 70f8baf38..98758ab7a 100644 --- a/internal/configuration/provider.go +++ b/internal/configuration/provider.go @@ -44,8 +44,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte DecoderConfig: &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( mapstructure.StringToSliceHookFunc(","), - StringToMailAddressFunc(), - ToTimeDurationFunc(), + StringToMailAddressHookFunc(), + ToTimeDurationHookFunc(), + StringToURLHookFunc(), ), Metadata: nil, Result: o,