refactor: include url hook func (#3022)

This adds a hook func for url.URL and *url.URL types to the configuration.
pull/3021/head^2
James Elliott 2022-03-16 16:16:46 +11:00 committed by GitHub
parent 99326c2688
commit dbe290a1c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 205 additions and 35 deletions

View File

@ -3,6 +3,7 @@ package configuration
import ( import (
"fmt" "fmt"
"net/mail" "net/mail"
"net/url"
"reflect" "reflect"
"time" "time"
@ -11,8 +12,8 @@ import (
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// StringToMailAddressFunc decodes a string into a mail.Address. // StringToMailAddressHookFunc decodes a string into a mail.Address.
func StringToMailAddressFunc() mapstructure.DecodeHookFunc { func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) { if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) {
return data, nil return data, nil
@ -36,12 +37,53 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
} }
} }
// ToTimeDurationFunc converts string and integer types to a time.Duration. // StringToURLHookFunc converts string types into a url.URL.
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType { func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var ( var ptr bool
ptr bool
) if f.Kind() != reflect.String {
return data, nil
}
ptr = t.Kind() == reflect.Ptr
typeURL := reflect.TypeOf(url.URL{})
if ptr && t.Elem() != typeURL {
return data, nil
} else if !ptr && t != typeURL {
return data, nil
}
dataStr := data.(string)
var parsedURL *url.URL
// Return an empty URL if there is an empty string.
if dataStr != "" {
if parsedURL, err = url.Parse(dataStr); err != nil {
return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err)
}
}
if ptr {
return parsedURL, nil
}
// Return an empty URL if there is an empty string.
if parsedURL == nil {
return url.URL{}, nil
}
return *parsedURL, nil
}
}
// ToTimeDurationHookFunc converts string and integer types to a time.Duration.
func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var ptr bool
switch f.Kind() { switch f.Kind() {
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64: case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:

View File

@ -1,6 +1,7 @@
package configuration package configuration
import ( import (
"net/url"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -8,8 +9,134 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) { func TestStringToURLHookFunc_ShouldNotParseStrings(t *testing.T) {
hook := ToTimeDurationFunc() hook := StringToURLHookFunc()
var (
from = "https://google.com/abc?a=123"
result interface{}
err error
resultTo string
resultPtrTo *time.Time
ok bool
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.NoError(t, err)
resultTo, ok = result.(string)
assert.True(t, ok)
assert.Equal(t, from, resultTo)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.NoError(t, err)
resultTo, ok = result.(string)
assert.True(t, ok)
assert.Equal(t, from, resultTo)
}
func TestStringToURLHookFunc_ShouldParseEmptyString(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = ""
result interface{}
err error
resultTo url.URL
resultPtrTo *url.URL
ok bool
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.NoError(t, err)
resultTo, ok = result.(url.URL)
assert.True(t, ok)
assert.Equal(t, "", resultTo.String())
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.NoError(t, err)
resultPtrTo, ok = result.(*url.URL)
assert.True(t, ok)
assert.Nil(t, resultPtrTo)
}
func TestStringToURLHookFunc_ShouldNotParseBadURLs(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = "*(!&@#(!*^$%"
result interface{}
err error
resultTo url.URL
resultPtrTo *url.URL
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
assert.Nil(t, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
assert.Nil(t, result)
}
func TestStringToURLHookFunc_ShouldParseURLs(t *testing.T) {
hook := StringToURLHookFunc()
var (
from = "https://google.com/abc?a=123"
result interface{}
err error
resultTo url.URL
resultPtrTo *url.URL
ok bool
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
assert.NoError(t, err)
resultTo, ok = result.(url.URL)
assert.True(t, ok)
assert.Equal(t, "https", resultTo.Scheme)
assert.Equal(t, "google.com", resultTo.Host)
assert.Equal(t, "/abc", resultTo.Path)
assert.Equal(t, "a=123", resultTo.RawQuery)
resultPtrTo, ok = result.(*url.URL)
assert.False(t, ok)
assert.Nil(t, resultPtrTo)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
assert.NoError(t, err)
resultPtrTo, ok = result.(*url.URL)
assert.True(t, ok)
assert.NotNil(t, resultPtrTo)
assert.Equal(t, "https", resultTo.Scheme)
assert.Equal(t, "google.com", resultTo.Host)
assert.Equal(t, "/abc", resultTo.Path)
assert.Equal(t, "a=123", resultTo.RawQuery)
resultTo, ok = result.(url.URL)
assert.False(t, ok)
}
func TestToTimeDurationHookFunc_ShouldParse_String(t *testing.T) {
hook := ToTimeDurationHookFunc()
var ( var (
from = "1h" from = "1h"
@ -30,8 +157,8 @@ func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_String_Years(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = "1y" from = "1y"
@ -52,8 +179,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_String_Months(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = "1M" from = "1M"
@ -74,8 +201,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_String_Weeks(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = "1w" from = "1w"
@ -96,8 +223,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_String_Days(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = "1d" from = "1d"
@ -118,8 +245,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) { func TestToTimeDurationHookFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = "abc" from = "abc"
@ -139,8 +266,8 @@ func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T
assert.Nil(t, result) assert.Nil(t, result)
} }
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_Int(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = 60 from = 60
@ -161,8 +288,8 @@ func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_Int32(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = int32(120) from = int32(120)
@ -183,8 +310,8 @@ func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_Int64(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = int64(30) from = int64(30)
@ -205,8 +332,8 @@ func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_Duration(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = time.Second * 30 from = time.Second * 30
@ -227,8 +354,8 @@ func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
assert.Equal(t, &expected, result) assert.Equal(t, &expected, result)
} }
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) { func TestToTimeDurationHookFunc_ShouldNotParse_Int64ToString(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = int64(30) from = int64(30)
@ -248,8 +375,8 @@ func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
assert.Equal(t, from, result) assert.Equal(t, from, result)
} }
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) { func TestToTimeDurationHookFunc_ShouldNotParse_FromBool(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = true from = true
@ -269,8 +396,8 @@ func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
assert.Equal(t, from, result) assert.Equal(t, from, result)
} }
func TestToTimeDurationFunc_ShouldParse_FromZero(t *testing.T) { func TestToTimeDurationHookFunc_ShouldParse_FromZero(t *testing.T) {
hook := ToTimeDurationFunc() hook := ToTimeDurationHookFunc()
var ( var (
from = 0 from = 0

View File

@ -44,8 +44,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
DecoderConfig: &mapstructure.DecoderConfig{ DecoderConfig: &mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc( DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToSliceHookFunc(","), mapstructure.StringToSliceHookFunc(","),
StringToMailAddressFunc(), StringToMailAddressHookFunc(),
ToTimeDurationFunc(), ToTimeDurationHookFunc(),
StringToURLHookFunc(),
), ),
Metadata: nil, Metadata: nil,
Result: o, Result: o,