refactor: include url hook func (#3022)

This adds a hook func for url.URL and *url.URL types to the configuration.
pull/3021/head^2
James Elliott 2022-03-16 16:16:46 +11:00 committed by GitHub
parent 99326c2688
commit dbe290a1c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 205 additions and 35 deletions

View File

@ -3,6 +3,7 @@ package configuration
import (
"fmt"
"net/mail"
"net/url"
"reflect"
"time"
@ -11,8 +12,8 @@ import (
"github.com/authelia/authelia/v4/internal/utils"
)
// StringToMailAddressFunc decodes a string into a mail.Address.
func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
// StringToMailAddressHookFunc decodes a string into a mail.Address.
func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) {
return data, nil
@ -36,12 +37,53 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
}
}
// ToTimeDurationFunc converts string and integer types to a time.Duration.
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType {
// StringToURLHookFunc converts string types into a url.URL.
func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var (
ptr bool
)
var ptr bool
if f.Kind() != reflect.String {
return data, nil
}
ptr = t.Kind() == reflect.Ptr
typeURL := reflect.TypeOf(url.URL{})
if ptr && t.Elem() != typeURL {
return data, nil
} else if !ptr && t != typeURL {
return data, nil
}
dataStr := data.(string)
var parsedURL *url.URL
// Return an empty URL if there is an empty string.
if dataStr != "" {
if parsedURL, err = url.Parse(dataStr); err != nil {
return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err)
}
}
if ptr {
return parsedURL, nil
}
// Return an empty URL if there is an empty string.
if parsedURL == nil {
return url.URL{}, nil
}
return *parsedURL, nil
}
}
// ToTimeDurationHookFunc converts string and integer types to a time.Duration.
func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var ptr bool
switch f.Kind() {
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:

View File

@ -1,6 +1,7 @@
package configuration
import (
"net/url"
"reflect"
"testing"
"time"
@ -8,8 +9,134 @@ import (
"github.com/stretchr/testify/assert"
)
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
hook := ToTimeDurationFunc()
func TestStringToURLHookFunc_ShouldNotParseStrings(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = "https://google.com/abc?a=123"
result interface{}
err error
resultTo string
resultPtrTo *time.Time
ok bool
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.NoError(t, err)
resultTo, ok = result.(string)
assert.True(t, ok)
assert.Equal(t, from, resultTo)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.NoError(t, err)
resultTo, ok = result.(string)
assert.True(t, ok)
assert.Equal(t, from, resultTo)
}
func TestStringToURLHookFunc_ShouldParseEmptyString(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = ""
result interface{}
err error
resultTo url.URL
resultPtrTo *url.URL
ok bool
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.NoError(t, err)
resultTo, ok = result.(url.URL)
assert.True(t, ok)
assert.Equal(t, "", resultTo.String())
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.NoError(t, err)
resultPtrTo, ok = result.(*url.URL)
assert.True(t, ok)
assert.Nil(t, resultPtrTo)
}
func TestStringToURLHookFunc_ShouldNotParseBadURLs(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = "*(!&@#(!*^$%"
result interface{}
err error
resultTo url.URL
resultPtrTo *url.URL
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
assert.Nil(t, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
assert.Nil(t, result)
}
func TestStringToURLHookFunc_ShouldParseURLs(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = "https://google.com/abc?a=123"
result interface{}
err error
resultTo url.URL
resultPtrTo *url.URL
ok bool
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.NoError(t, err)
resultTo, ok = result.(url.URL)
assert.True(t, ok)
assert.Equal(t, "https", resultTo.Scheme)
assert.Equal(t, "google.com", resultTo.Host)
assert.Equal(t, "/abc", resultTo.Path)
assert.Equal(t, "a=123", resultTo.RawQuery)
resultPtrTo, ok = result.(*url.URL)
assert.False(t, ok)
assert.Nil(t, resultPtrTo)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.NoError(t, err)
resultPtrTo, ok = result.(*url.URL)
assert.True(t, ok)
assert.NotNil(t, resultPtrTo)
assert.Equal(t, "https", resultTo.Scheme)
assert.Equal(t, "google.com", resultTo.Host)
assert.Equal(t, "/abc", resultTo.Path)
assert.Equal(t, "a=123", resultTo.RawQuery)
resultTo, ok = result.(url.URL)
assert.False(t, ok)
}
func TestToTimeDurationHookFunc_ShouldParse_String(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = "1h"
@ -30,8 +157,8 @@ func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_String_Years(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = "1y"
@ -52,8 +179,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_String_Months(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = "1M"
@ -74,8 +201,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_String_Weeks(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = "1w"
@ -96,8 +223,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_String_Days(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = "1d"
@ -118,8 +245,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = "abc"
@ -139,8 +266,8 @@ func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T
assert.Nil(t, result)
}
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_Int(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = 60
@ -161,8 +288,8 @@ func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_Int32(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = int32(120)
@ -183,8 +310,8 @@ func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_Int64(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = int64(30)
@ -205,8 +332,8 @@ func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_Duration(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = time.Second * 30
@ -227,8 +354,8 @@ func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldNotParse_Int64ToString(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = int64(30)
@ -248,8 +375,8 @@ func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
assert.Equal(t, from, result)
}
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldNotParse_FromBool(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = true
@ -269,8 +396,8 @@ func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
assert.Equal(t, from, result)
}
func TestToTimeDurationFunc_ShouldParse_FromZero(t *testing.T) {
hook := ToTimeDurationFunc()
func TestToTimeDurationHookFunc_ShouldParse_FromZero(t *testing.T) {
hook := ToTimeDurationHookFunc()
var (
from = 0

View File

@ -44,8 +44,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
DecoderConfig: &mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToSliceHookFunc(","),
StringToMailAddressFunc(),
ToTimeDurationFunc(),
StringToMailAddressHookFunc(),
ToTimeDurationHookFunc(),
StringToURLHookFunc(),
),
Metadata: nil,
Result: o,