refactor: include url hook func (#3022)
This adds a hook func for url.URL and *url.URL types to the configuration.pull/3021/head^2
parent
99326c2688
commit
dbe290a1c9
|
@ -3,6 +3,7 @@ package configuration
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -11,8 +12,8 @@ import (
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StringToMailAddressFunc decodes a string into a mail.Address.
|
// StringToMailAddressHookFunc decodes a string into a mail.Address.
|
||||||
func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
|
func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType {
|
||||||
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||||
if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) {
|
if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) {
|
||||||
return data, nil
|
return data, nil
|
||||||
|
@ -36,12 +37,53 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToTimeDurationFunc converts string and integer types to a time.Duration.
|
// StringToURLHookFunc converts string types into a url.URL.
|
||||||
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType {
|
func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
|
||||||
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||||
var (
|
var ptr bool
|
||||||
ptr bool
|
|
||||||
)
|
if f.Kind() != reflect.String {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr = t.Kind() == reflect.Ptr
|
||||||
|
|
||||||
|
typeURL := reflect.TypeOf(url.URL{})
|
||||||
|
|
||||||
|
if ptr && t.Elem() != typeURL {
|
||||||
|
return data, nil
|
||||||
|
} else if !ptr && t != typeURL {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dataStr := data.(string)
|
||||||
|
|
||||||
|
var parsedURL *url.URL
|
||||||
|
|
||||||
|
// Return an empty URL if there is an empty string.
|
||||||
|
if dataStr != "" {
|
||||||
|
if parsedURL, err = url.Parse(dataStr); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ptr {
|
||||||
|
return parsedURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return an empty URL if there is an empty string.
|
||||||
|
if parsedURL == nil {
|
||||||
|
return url.URL{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return *parsedURL, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTimeDurationHookFunc converts string and integer types to a time.Duration.
|
||||||
|
func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType {
|
||||||
|
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||||
|
var ptr bool
|
||||||
|
|
||||||
switch f.Kind() {
|
switch f.Kind() {
|
||||||
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:
|
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package configuration
|
package configuration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -8,8 +9,134 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
|
func TestStringToURLHookFunc_ShouldNotParseStrings(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := StringToURLHookFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "https://google.com/abc?a=123"
|
||||||
|
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
resultTo string
|
||||||
|
resultPtrTo *time.Time
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resultTo, ok = result.(string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, from, resultTo)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resultTo, ok = result.(string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, from, resultTo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringToURLHookFunc_ShouldParseEmptyString(t *testing.T) {
|
||||||
|
hook := StringToURLHookFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = ""
|
||||||
|
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
resultTo url.URL
|
||||||
|
resultPtrTo *url.URL
|
||||||
|
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resultTo, ok = result.(url.URL)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "", resultTo.String())
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resultPtrTo, ok = result.(*url.URL)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Nil(t, resultPtrTo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringToURLHookFunc_ShouldNotParseBadURLs(t *testing.T) {
|
||||||
|
hook := StringToURLHookFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "*(!&@#(!*^$%"
|
||||||
|
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
resultTo url.URL
|
||||||
|
resultPtrTo *url.URL
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||||
|
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
|
||||||
|
assert.Nil(t, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||||
|
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
|
||||||
|
assert.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringToURLHookFunc_ShouldParseURLs(t *testing.T) {
|
||||||
|
hook := StringToURLHookFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "https://google.com/abc?a=123"
|
||||||
|
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
|
||||||
|
resultTo url.URL
|
||||||
|
resultPtrTo *url.URL
|
||||||
|
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resultTo, ok = result.(url.URL)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "https", resultTo.Scheme)
|
||||||
|
assert.Equal(t, "google.com", resultTo.Host)
|
||||||
|
assert.Equal(t, "/abc", resultTo.Path)
|
||||||
|
assert.Equal(t, "a=123", resultTo.RawQuery)
|
||||||
|
|
||||||
|
resultPtrTo, ok = result.(*url.URL)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Nil(t, resultPtrTo)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resultPtrTo, ok = result.(*url.URL)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotNil(t, resultPtrTo)
|
||||||
|
|
||||||
|
assert.Equal(t, "https", resultTo.Scheme)
|
||||||
|
assert.Equal(t, "google.com", resultTo.Host)
|
||||||
|
assert.Equal(t, "/abc", resultTo.Path)
|
||||||
|
assert.Equal(t, "a=123", resultTo.RawQuery)
|
||||||
|
|
||||||
|
resultTo, ok = result.(url.URL)
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationHookFunc_ShouldParse_String(t *testing.T) {
|
||||||
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = "1h"
|
from = "1h"
|
||||||
|
@ -30,8 +157,8 @@ func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_String_Years(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = "1y"
|
from = "1y"
|
||||||
|
@ -52,8 +179,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_String_Months(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = "1M"
|
from = "1M"
|
||||||
|
@ -74,8 +201,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_String_Weeks(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = "1w"
|
from = "1w"
|
||||||
|
@ -96,8 +223,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_String_Days(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = "1d"
|
from = "1d"
|
||||||
|
@ -118,8 +245,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = "abc"
|
from = "abc"
|
||||||
|
@ -139,8 +266,8 @@ func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T
|
||||||
assert.Nil(t, result)
|
assert.Nil(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_Int(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = 60
|
from = 60
|
||||||
|
@ -161,8 +288,8 @@ func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_Int32(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = int32(120)
|
from = int32(120)
|
||||||
|
@ -183,8 +310,8 @@ func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_Int64(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = int64(30)
|
from = int64(30)
|
||||||
|
@ -205,8 +332,8 @@ func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_Duration(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = time.Second * 30
|
from = time.Second * 30
|
||||||
|
@ -227,8 +354,8 @@ func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
|
||||||
assert.Equal(t, &expected, result)
|
assert.Equal(t, &expected, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = int64(30)
|
from = int64(30)
|
||||||
|
@ -248,8 +375,8 @@ func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
||||||
assert.Equal(t, from, result)
|
assert.Equal(t, from, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldNotParse_FromBool(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = true
|
from = true
|
||||||
|
@ -269,8 +396,8 @@ func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
|
||||||
assert.Equal(t, from, result)
|
assert.Equal(t, from, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToTimeDurationFunc_ShouldParse_FromZero(t *testing.T) {
|
func TestToTimeDurationHookFunc_ShouldParse_FromZero(t *testing.T) {
|
||||||
hook := ToTimeDurationFunc()
|
hook := ToTimeDurationHookFunc()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
from = 0
|
from = 0
|
||||||
|
|
|
@ -44,8 +44,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
|
||||||
DecoderConfig: &mapstructure.DecoderConfig{
|
DecoderConfig: &mapstructure.DecoderConfig{
|
||||||
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
||||||
mapstructure.StringToSliceHookFunc(","),
|
mapstructure.StringToSliceHookFunc(","),
|
||||||
StringToMailAddressFunc(),
|
StringToMailAddressHookFunc(),
|
||||||
ToTimeDurationFunc(),
|
ToTimeDurationHookFunc(),
|
||||||
|
StringToURLHookFunc(),
|
||||||
),
|
),
|
||||||
Metadata: nil,
|
Metadata: nil,
|
||||||
Result: o,
|
Result: o,
|
||||||
|
|
Loading…
Reference in New Issue