test: add misc missing tests (#5479)
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>pull/5481/head
parent
e784a72735
commit
2e8a460a66
|
@ -0,0 +1,20 @@
|
||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewPrometheus(t *testing.T) {
|
||||||
|
p := NewPrometheus()
|
||||||
|
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
|
||||||
|
p.RecordRequest("400", "GET", time.Second)
|
||||||
|
p.RecordAuthz("400")
|
||||||
|
p.RecordAuthn(true, false, "WebAuthn")
|
||||||
|
p.RecordAuthn(true, false, "1fa")
|
||||||
|
p.RecordAuthenticationDuration(true, time.Second)
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
package random
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/big"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCryptographical(t *testing.T) {
|
||||||
|
p := &Cryptographical{}
|
||||||
|
|
||||||
|
data := make([]byte, 10)
|
||||||
|
|
||||||
|
n, err := p.Read(data)
|
||||||
|
assert.Equal(t, 10, n)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
data2, err := p.BytesErr()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, data2, 72)
|
||||||
|
|
||||||
|
data2 = p.Bytes()
|
||||||
|
assert.Len(t, data2, 72)
|
||||||
|
|
||||||
|
data2 = p.BytesCustom(74, []byte(CharSetAlphabetic))
|
||||||
|
assert.Len(t, data2, 74)
|
||||||
|
|
||||||
|
data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, data2, 76)
|
||||||
|
|
||||||
|
data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, data2, 72)
|
||||||
|
|
||||||
|
strdata := p.StringCustom(10, CharSetAlphabetic)
|
||||||
|
assert.Len(t, strdata, 10)
|
||||||
|
|
||||||
|
strdata, err = p.StringCustomErr(11, CharSetAlphabetic)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, strdata, 11)
|
||||||
|
|
||||||
|
i := p.Intn(999)
|
||||||
|
assert.Greater(t, i, 0)
|
||||||
|
assert.Less(t, i, 999)
|
||||||
|
|
||||||
|
i, err = p.IntnErr(999)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Greater(t, i, 0)
|
||||||
|
assert.Less(t, i, 999)
|
||||||
|
|
||||||
|
i, err = p.IntnErr(-4)
|
||||||
|
assert.EqualError(t, err, "n must be more than 0")
|
||||||
|
assert.Equal(t, 0, i)
|
||||||
|
|
||||||
|
bi := p.Int(big.NewInt(999))
|
||||||
|
assert.Greater(t, bi.Int64(), int64(0))
|
||||||
|
assert.Less(t, bi.Int64(), int64(999))
|
||||||
|
|
||||||
|
bi = p.Int(nil)
|
||||||
|
assert.Equal(t, int64(-1), bi.Int64())
|
||||||
|
|
||||||
|
bi, err = p.IntErr(nil)
|
||||||
|
assert.Nil(t, bi)
|
||||||
|
assert.EqualError(t, err, "max is required")
|
||||||
|
|
||||||
|
bi, err = p.IntErr(big.NewInt(-1))
|
||||||
|
assert.Nil(t, bi)
|
||||||
|
assert.EqualError(t, err, "max must be 1 or more")
|
||||||
|
|
||||||
|
prime, err := p.Prime(64)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, prime)
|
||||||
|
}
|
|
@ -112,6 +112,10 @@ func (r *Mathematical) Intn(n int) int {
|
||||||
|
|
||||||
// IntnErr returns a random int error combination with a maximum of n.
|
// IntnErr returns a random int error combination with a maximum of n.
|
||||||
func (r *Mathematical) IntnErr(n int) (output int, err error) {
|
func (r *Mathematical) IntnErr(n int) (output int, err error) {
|
||||||
|
if n <= 0 {
|
||||||
|
return 0, fmt.Errorf("n must be more than 0")
|
||||||
|
}
|
||||||
|
|
||||||
return r.Intn(n), nil
|
return r.Intn(n), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,15 +136,11 @@ func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) {
|
||||||
return nil, fmt.Errorf("max is required")
|
return nil, fmt.Errorf("max is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
if max.Sign() <= 0 {
|
if max.Int64() <= 0 {
|
||||||
return nil, fmt.Errorf("max must be 1 or more")
|
return nil, fmt.Errorf("max must be 1 or more")
|
||||||
}
|
}
|
||||||
|
|
||||||
r.lock.Lock()
|
return big.NewInt(int64(r.Intn(int(max.Int64())))), nil
|
||||||
|
|
||||||
defer r.lock.Unlock()
|
|
||||||
|
|
||||||
return big.NewInt(int64(r.Intn(max.Sign()))), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prime returns a number of the given bit length that is prime with high probability. Prime will return error for any
|
// Prime returns a number of the given bit length that is prime with high probability. Prime will return error for any
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
package random
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/big"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMathematical(t *testing.T) {
|
||||||
|
p := NewMathematical()
|
||||||
|
|
||||||
|
data := make([]byte, 10)
|
||||||
|
|
||||||
|
n, err := p.Read(data)
|
||||||
|
assert.Equal(t, 10, n)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
data2, err := p.BytesErr()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, data2, 72)
|
||||||
|
|
||||||
|
data2 = p.Bytes()
|
||||||
|
assert.Len(t, data2, 72)
|
||||||
|
|
||||||
|
data2 = p.BytesCustom(74, []byte(CharSetAlphabetic))
|
||||||
|
assert.Len(t, data2, 74)
|
||||||
|
|
||||||
|
data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, data2, 76)
|
||||||
|
|
||||||
|
data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, data2, 72)
|
||||||
|
|
||||||
|
strdata := p.StringCustom(10, CharSetAlphabetic)
|
||||||
|
assert.Len(t, strdata, 10)
|
||||||
|
|
||||||
|
strdata, err = p.StringCustomErr(11, CharSetAlphabetic)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, strdata, 11)
|
||||||
|
|
||||||
|
i := p.Intn(999)
|
||||||
|
assert.Greater(t, i, 0)
|
||||||
|
assert.Less(t, i, 999)
|
||||||
|
|
||||||
|
i, err = p.IntnErr(999)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Greater(t, i, 0)
|
||||||
|
assert.Less(t, i, 999)
|
||||||
|
|
||||||
|
i, err = p.IntnErr(-4)
|
||||||
|
assert.EqualError(t, err, "n must be more than 0")
|
||||||
|
assert.Equal(t, 0, i)
|
||||||
|
|
||||||
|
bi := p.Int(big.NewInt(999))
|
||||||
|
assert.Greater(t, bi.Int64(), int64(0))
|
||||||
|
assert.Less(t, bi.Int64(), int64(999))
|
||||||
|
|
||||||
|
bi = p.Int(nil)
|
||||||
|
assert.Equal(t, int64(-1), bi.Int64())
|
||||||
|
|
||||||
|
bi, err = p.IntErr(nil)
|
||||||
|
assert.Nil(t, bi)
|
||||||
|
assert.EqualError(t, err, "max is required")
|
||||||
|
|
||||||
|
bi, err = p.IntErr(big.NewInt(-1))
|
||||||
|
assert.Nil(t, bi)
|
||||||
|
assert.EqualError(t, err, "max must be 1 or more")
|
||||||
|
|
||||||
|
prime, err := p.Prime(64)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, prime)
|
||||||
|
}
|
|
@ -12,12 +12,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRegulator create a regulator instance.
|
// NewRegulator create a regulator instance.
|
||||||
func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
func NewRegulator(config schema.RegulationConfiguration, store storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
||||||
return &Regulator{
|
return &Regulator{
|
||||||
enabled: config.MaxRetries > 0,
|
enabled: config.MaxRetries > 0,
|
||||||
storageProvider: provider,
|
store: store,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ func NewRegulator(config schema.RegulationConfiguration, provider storage.Regula
|
||||||
func (r *Regulator) Mark(ctx Context, successful, banned bool, username, requestURI, requestMethod, authType string) error {
|
func (r *Regulator) Mark(ctx Context, successful, banned bool, username, requestURI, requestMethod, authType string) error {
|
||||||
ctx.RecordAuthn(successful, banned, strings.ToLower(authType))
|
ctx.RecordAuthn(successful, banned, strings.ToLower(authType))
|
||||||
|
|
||||||
return r.storageProvider.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{
|
return r.store.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{
|
||||||
Time: r.clock.Now(),
|
Time: r.clock.Now(),
|
||||||
Successful: successful,
|
Successful: successful,
|
||||||
Banned: banned,
|
Banned: banned,
|
||||||
|
@ -46,7 +46,7 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
|
||||||
return time.Time{}, nil
|
return time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
|
attempts, err := r.store.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Time{}, nil
|
return time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,46 +1,69 @@
|
||||||
package regulation_test
|
package regulation_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"fmt"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
"github.com/authelia/authelia/v4/internal/mocks"
|
"github.com/authelia/authelia/v4/internal/mocks"
|
||||||
"github.com/authelia/authelia/v4/internal/model"
|
"github.com/authelia/authelia/v4/internal/model"
|
||||||
"github.com/authelia/authelia/v4/internal/regulation"
|
"github.com/authelia/authelia/v4/internal/regulation"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type RegulatorSuite struct {
|
type RegulatorSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
|
|
||||||
ctx context.Context
|
mock *mocks.MockAutheliaCtx
|
||||||
ctrl *gomock.Controller
|
|
||||||
storageMock *mocks.MockStorage
|
|
||||||
config schema.RegulationConfiguration
|
|
||||||
clock utils.TestingClock
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RegulatorSuite) SetupTest() {
|
func (s *RegulatorSuite) SetupTest() {
|
||||||
s.ctrl = gomock.NewController(s.T())
|
s.mock = mocks.NewMockAutheliaCtx(s.T())
|
||||||
s.storageMock = mocks.NewMockStorage(s.ctrl)
|
s.mock.Ctx.Configuration.Regulation = schema.RegulationConfiguration{
|
||||||
s.ctx = context.Background()
|
|
||||||
|
|
||||||
s.config = schema.RegulationConfiguration{
|
|
||||||
MaxRetries: 3,
|
MaxRetries: 3,
|
||||||
BanTime: time.Second * 180,
|
BanTime: time.Second * 180,
|
||||||
FindTime: time.Second * 30,
|
FindTime: time.Second * 30,
|
||||||
}
|
}
|
||||||
s.clock.Set(time.Now())
|
|
||||||
|
s.mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, "127.0.0.1")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RegulatorSuite) TearDownTest() {
|
func (s *RegulatorSuite) TearDownTest() {
|
||||||
s.ctrl.Finish()
|
s.mock.Ctrl.Finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RegulatorSuite) TestShouldMark() {
|
||||||
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
|
s.mock.StorageMock.EXPECT().AppendAuthenticationLog(s.mock.Ctx, model.AuthenticationAttempt{
|
||||||
|
Time: s.mock.Clock.Now(),
|
||||||
|
Successful: true,
|
||||||
|
Banned: false,
|
||||||
|
Username: "john",
|
||||||
|
Type: "1fa",
|
||||||
|
RemoteIP: model.NewNullIP(net.ParseIP("127.0.0.1")),
|
||||||
|
RequestURI: "https://google.com",
|
||||||
|
RequestMethod: fasthttp.MethodGet,
|
||||||
|
})
|
||||||
|
|
||||||
|
s.NoError(regulator.Mark(s.mock.Ctx, true, false, "john", "https://google.com", fasthttp.MethodGet, "1fa"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RegulatorSuite) TestShouldHandleRegulateError() {
|
||||||
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
|
s.mock.StorageMock.EXPECT().LoadAuthenticationLogs(s.mock.Ctx, "john", s.mock.Clock.Now().Add(-s.mock.Ctx.Configuration.Regulation.BanTime), 10, 0).Return(nil, fmt.Errorf("failed"))
|
||||||
|
|
||||||
|
until, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
|
|
||||||
|
s.NoError(err)
|
||||||
|
s.Equal(time.Time{}, until)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
||||||
|
@ -48,17 +71,17 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: true,
|
Successful: true,
|
||||||
Time: s.clock.Now().Add(-4 * time.Minute),
|
Time: s.mock.Clock.Now().Add(-4 * time.Minute),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,27 +92,27 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-1 * time.Second),
|
Time: s.mock.Clock.Now().Add(-1 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-90 * time.Second),
|
Time: s.mock.Clock.Now().Add(-90 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-180 * time.Second),
|
Time: s.mock.Clock.Now().Add(-180 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,32 +123,32 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-1 * time.Second),
|
Time: s.mock.Clock.Now().Add(-1 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-4 * time.Second),
|
Time: s.mock.Clock.Now().Add(-4 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-6 * time.Second),
|
Time: s.mock.Clock.Now().Add(-6 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-180 * time.Second),
|
Time: s.mock.Clock.Now().Add(-180 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,27 +161,27 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-31 * time.Second),
|
Time: s.mock.Clock.Now().Add(-31 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-34 * time.Second),
|
Time: s.mock.Clock.Now().Add(-34 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-36 * time.Second),
|
Time: s.mock.Clock.Now().Add(-36 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,22 +190,22 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-34 * time.Second),
|
Time: s.mock.Clock.Now().Add(-34 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-36 * time.Second),
|
Time: s.mock.Clock.Now().Add(-36 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,7 +214,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-14 * time.Second),
|
Time: s.mock.Clock.Now().Add(-14 * time.Second),
|
||||||
},
|
},
|
||||||
// more than 30 seconds elapsed between this auth and the preceding one.
|
// more than 30 seconds elapsed between this auth and the preceding one.
|
||||||
// In that case we don't need to regulate the user even though the number
|
// In that case we don't need to regulate the user even though the number
|
||||||
|
@ -199,22 +222,22 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-94 * time.Second),
|
Time: s.mock.Clock.Now().Add(-94 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-96 * time.Second),
|
Time: s.mock.Clock.Now().Add(-96 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,34 +246,34 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-90 * time.Second),
|
Time: s.mock.Clock.Now().Add(-90 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: true,
|
Successful: true,
|
||||||
Time: s.clock.Now().Add(-93 * time.Second),
|
Time: s.mock.Clock.Now().Add(-93 * time.Second),
|
||||||
},
|
},
|
||||||
// The user was almost banned but he did a successful attempt. Therefore, even if the next
|
// The user was almost banned but he did a successful attempt. Therefore, even if the next
|
||||||
// failure happens within FindTime, he should not be banned.
|
// failure happens within FindTime, he should not be banned.
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-94 * time.Second),
|
Time: s.mock.Clock.Now().Add(-94 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-96 * time.Second),
|
Time: s.mock.Clock.Now().Add(-96 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,22 +288,22 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-31 * time.Second),
|
Time: s.mock.Clock.Now().Add(-31 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-34 * time.Second),
|
Time: s.mock.Clock.Now().Add(-34 * time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "john",
|
Username: "john",
|
||||||
Successful: false,
|
Successful: false,
|
||||||
Time: s.clock.Now().Add(-36 * time.Second),
|
Time: s.mock.Clock.Now().Add(-36 * time.Second),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.storageMock.EXPECT().
|
s.mock.StorageMock.EXPECT().
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
// Check Disabled Functionality.
|
// Check Disabled Functionality.
|
||||||
|
@ -290,8 +313,8 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
||||||
BanTime: time.Second * 180,
|
BanTime: time.Second * 180,
|
||||||
}
|
}
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(config, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock)
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
|
||||||
// Check Enabled Functionality.
|
// Check Enabled Functionality.
|
||||||
|
@ -301,7 +324,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
||||||
BanTime: time.Second * 180,
|
BanTime: time.Second * 180,
|
||||||
}
|
}
|
||||||
|
|
||||||
regulator = regulation.NewRegulator(config, s.storageMock, &s.clock)
|
regulator = regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock)
|
||||||
_, err = regulator.Regulate(s.ctx, "john")
|
_, err = regulator.Regulate(s.mock.Ctx, "john")
|
||||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ type Regulator struct {
|
||||||
|
|
||||||
config schema.RegulationConfiguration
|
config schema.RegulationConfiguration
|
||||||
|
|
||||||
storageProvider storage.RegulatorProvider
|
store storage.RegulatorProvider
|
||||||
|
|
||||||
clock utils.Clock
|
clock utils.Clock
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,59 +28,6 @@ const (
|
||||||
tableEncryption = "encryption"
|
tableEncryption = "encryption"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OAuth2SessionType represents the potential OAuth 2.0 session types.
|
|
||||||
type OAuth2SessionType int
|
|
||||||
|
|
||||||
// Representation of specific OAuth 2.0 session types.
|
|
||||||
const (
|
|
||||||
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
|
|
||||||
OAuth2SessionTypeAuthorizeCode
|
|
||||||
OAuth2SessionTypeOpenIDConnect
|
|
||||||
OAuth2SessionTypePAR
|
|
||||||
OAuth2SessionTypePKCEChallenge
|
|
||||||
OAuth2SessionTypeRefreshToken
|
|
||||||
)
|
|
||||||
|
|
||||||
// String returns a string representation of this OAuth2SessionType.
|
|
||||||
func (s OAuth2SessionType) String() string {
|
|
||||||
switch s {
|
|
||||||
case OAuth2SessionTypeAccessToken:
|
|
||||||
return "access token"
|
|
||||||
case OAuth2SessionTypeAuthorizeCode:
|
|
||||||
return "authorization code"
|
|
||||||
case OAuth2SessionTypeOpenIDConnect:
|
|
||||||
return "openid connect"
|
|
||||||
case OAuth2SessionTypePAR:
|
|
||||||
return "pushed authorization request context"
|
|
||||||
case OAuth2SessionTypePKCEChallenge:
|
|
||||||
return "pkce challenge"
|
|
||||||
case OAuth2SessionTypeRefreshToken:
|
|
||||||
return "refresh token"
|
|
||||||
default:
|
|
||||||
return "invalid"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Table returns the table name for this session type.
|
|
||||||
func (s OAuth2SessionType) Table() string {
|
|
||||||
switch s {
|
|
||||||
case OAuth2SessionTypeAccessToken:
|
|
||||||
return tableOAuth2AccessTokenSession
|
|
||||||
case OAuth2SessionTypeAuthorizeCode:
|
|
||||||
return tableOAuth2AuthorizeCodeSession
|
|
||||||
case OAuth2SessionTypeOpenIDConnect:
|
|
||||||
return tableOAuth2OpenIDConnectSession
|
|
||||||
case OAuth2SessionTypePAR:
|
|
||||||
return tableOAuth2PARContext
|
|
||||||
case OAuth2SessionTypePKCEChallenge:
|
|
||||||
return tableOAuth2PKCERequestSession
|
|
||||||
case OAuth2SessionTypeRefreshToken:
|
|
||||||
return tableOAuth2RefreshTokenSession
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
encryptionNameCheck = "check"
|
encryptionNameCheck = "check"
|
||||||
)
|
)
|
||||||
|
|
|
@ -93,3 +93,56 @@ func (r EncryptionValidationTableResult) ResultDescriptor() string {
|
||||||
|
|
||||||
return "SUCCESS"
|
return "SUCCESS"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuth2SessionType represents the potential OAuth 2.0 session types.
|
||||||
|
type OAuth2SessionType int
|
||||||
|
|
||||||
|
// Representation of specific OAuth 2.0 session types.
|
||||||
|
const (
|
||||||
|
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
|
||||||
|
OAuth2SessionTypeAuthorizeCode
|
||||||
|
OAuth2SessionTypeOpenIDConnect
|
||||||
|
OAuth2SessionTypePAR
|
||||||
|
OAuth2SessionTypePKCEChallenge
|
||||||
|
OAuth2SessionTypeRefreshToken
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns a string representation of this OAuth2SessionType.
|
||||||
|
func (s OAuth2SessionType) String() string {
|
||||||
|
switch s {
|
||||||
|
case OAuth2SessionTypeAccessToken:
|
||||||
|
return "access token"
|
||||||
|
case OAuth2SessionTypeAuthorizeCode:
|
||||||
|
return "authorization code"
|
||||||
|
case OAuth2SessionTypeOpenIDConnect:
|
||||||
|
return "openid connect"
|
||||||
|
case OAuth2SessionTypePAR:
|
||||||
|
return "pushed authorization request context"
|
||||||
|
case OAuth2SessionTypePKCEChallenge:
|
||||||
|
return "pkce challenge"
|
||||||
|
case OAuth2SessionTypeRefreshToken:
|
||||||
|
return "refresh token"
|
||||||
|
default:
|
||||||
|
return "invalid"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Table returns the table name for this session type.
|
||||||
|
func (s OAuth2SessionType) Table() string {
|
||||||
|
switch s {
|
||||||
|
case OAuth2SessionTypeAccessToken:
|
||||||
|
return tableOAuth2AccessTokenSession
|
||||||
|
case OAuth2SessionTypeAuthorizeCode:
|
||||||
|
return tableOAuth2AuthorizeCodeSession
|
||||||
|
case OAuth2SessionTypeOpenIDConnect:
|
||||||
|
return tableOAuth2OpenIDConnectSession
|
||||||
|
case OAuth2SessionTypePAR:
|
||||||
|
return tableOAuth2PARContext
|
||||||
|
case OAuth2SessionTypePKCEChallenge:
|
||||||
|
return tableOAuth2PKCERequestSession
|
||||||
|
case OAuth2SessionTypeRefreshToken:
|
||||||
|
return tableOAuth2RefreshTokenSession
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncryptionValidationResult(t *testing.T) {
|
||||||
|
result := &EncryptionValidationResult{
|
||||||
|
InvalidCheckValue: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, result.Success())
|
||||||
|
assert.True(t, result.Checked())
|
||||||
|
|
||||||
|
result = &EncryptionValidationResult{
|
||||||
|
InvalidCheckValue: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, result.Success())
|
||||||
|
assert.True(t, result.Checked())
|
||||||
|
|
||||||
|
result = &EncryptionValidationResult{
|
||||||
|
InvalidCheckValue: false,
|
||||||
|
Tables: map[string]EncryptionValidationTableResult{
|
||||||
|
tableWebAuthnDevices: {
|
||||||
|
Invalid: 10,
|
||||||
|
Total: 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.Equal(t, "FAILURE", result.Tables[tableWebAuthnDevices].ResultDescriptor())
|
||||||
|
|
||||||
|
assert.False(t, result.Success())
|
||||||
|
assert.True(t, result.Checked())
|
||||||
|
|
||||||
|
result = &EncryptionValidationResult{
|
||||||
|
InvalidCheckValue: false,
|
||||||
|
Tables: map[string]EncryptionValidationTableResult{
|
||||||
|
tableWebAuthnDevices: {
|
||||||
|
Error: fmt.Errorf("failed to check table"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, result.Success())
|
||||||
|
assert.False(t, result.Checked())
|
||||||
|
assert.Equal(t, "N/A", result.Tables[tableWebAuthnDevices].ResultDescriptor())
|
||||||
|
|
||||||
|
result = &EncryptionValidationResult{
|
||||||
|
InvalidCheckValue: false,
|
||||||
|
Tables: map[string]EncryptionValidationTableResult{
|
||||||
|
tableWebAuthnDevices: {
|
||||||
|
Total: 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, result.Success())
|
||||||
|
assert.True(t, result.Checked())
|
||||||
|
assert.Equal(t, "SUCCESS", result.Tables[tableWebAuthnDevices].ResultDescriptor())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuth2SessionType(t *testing.T) {
|
||||||
|
assert.Equal(t, "access token", OAuth2SessionTypeAccessToken.String())
|
||||||
|
assert.Equal(t, tableOAuth2AccessTokenSession, OAuth2SessionTypeAccessToken.Table())
|
||||||
|
|
||||||
|
assert.Equal(t, "authorization code", OAuth2SessionTypeAuthorizeCode.String())
|
||||||
|
assert.Equal(t, tableOAuth2AuthorizeCodeSession, OAuth2SessionTypeAuthorizeCode.Table())
|
||||||
|
|
||||||
|
assert.Equal(t, "openid connect", OAuth2SessionTypeOpenIDConnect.String())
|
||||||
|
assert.Equal(t, tableOAuth2OpenIDConnectSession, OAuth2SessionTypeOpenIDConnect.Table())
|
||||||
|
|
||||||
|
assert.Equal(t, "pushed authorization request context", OAuth2SessionTypePAR.String())
|
||||||
|
assert.Equal(t, tableOAuth2PARContext, OAuth2SessionTypePAR.Table())
|
||||||
|
|
||||||
|
assert.Equal(t, "pkce challenge", OAuth2SessionTypePKCEChallenge.String())
|
||||||
|
assert.Equal(t, tableOAuth2PKCERequestSession, OAuth2SessionTypePKCEChallenge.Table())
|
||||||
|
|
||||||
|
assert.Equal(t, "refresh token", OAuth2SessionTypeRefreshToken.String())
|
||||||
|
assert.Equal(t, tableOAuth2RefreshTokenSession, OAuth2SessionTypeRefreshToken.Table())
|
||||||
|
|
||||||
|
assert.Equal(t, "invalid", OAuth2SessionType(-1).String())
|
||||||
|
assert.Equal(t, "", OAuth2SessionType(-1).Table())
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
|
func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
|
||||||
var key [32]byte = sha256.Sum256([]byte("the key"))
|
var key = sha256.Sum256([]byte("the key"))
|
||||||
|
|
||||||
var secret = "the secret"
|
var secret = "the secret"
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
|
func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
|
||||||
var key [32]byte = sha256.Sum256([]byte("the key"))
|
var key = sha256.Sum256([]byte("the key"))
|
||||||
|
|
||||||
var secret = "the secret"
|
var secret = "the secret"
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldFailDecryptOnInvalidCypherText(t *testing.T) {
|
func TestShouldFailDecryptOnInvalidCypherText(t *testing.T) {
|
||||||
var key [32]byte = sha256.Sum256([]byte("the key"))
|
var key = sha256.Sum256([]byte("the key"))
|
||||||
|
|
||||||
encryptedSecret := []byte("abc123")
|
encryptedSecret := []byte("abc123")
|
||||||
|
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewWriteCloser creates a new io.WriteCloser from an io.Writer.
|
|
||||||
func NewWriteCloser(wr io.Writer) io.WriteCloser {
|
|
||||||
return &WriteCloser{wr: wr}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteCloser is a io.Writer with an io.Closer.
|
|
||||||
type WriteCloser struct {
|
|
||||||
wr io.Writer
|
|
||||||
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write to the io.Writer.
|
|
||||||
func (w *WriteCloser) Write(p []byte) (n int, err error) {
|
|
||||||
if w.closed {
|
|
||||||
return -1, errors.New("already closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.wr.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the io.Closer.
|
|
||||||
func (w *WriteCloser) Close() error {
|
|
||||||
w.closed = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -8,6 +8,97 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestIsStringAbsURL(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
have string
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ShouldBeAbs",
|
||||||
|
"https://google.com",
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ShouldNotBeAbs",
|
||||||
|
"google.com",
|
||||||
|
"could not parse 'google.com' as a URL",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
theError := IsStringAbsURL(tc.have)
|
||||||
|
|
||||||
|
if tc.err == "" {
|
||||||
|
assert.NoError(t, theError)
|
||||||
|
} else {
|
||||||
|
assert.EqualError(t, theError, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsStringInSliceF(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
needle string
|
||||||
|
haystack []string
|
||||||
|
isEqual func(needle, item string) bool
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ShouldBePresent",
|
||||||
|
"good",
|
||||||
|
[]string{"good"},
|
||||||
|
func(needle, item string) bool {
|
||||||
|
return needle == item
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ShouldNotBePresent",
|
||||||
|
"bad",
|
||||||
|
[]string{"good"},
|
||||||
|
func(needle, item string) bool {
|
||||||
|
return needle == item
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.expected, IsStringInSliceF(tc.needle, tc.haystack, tc.isEqual))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringHTMLEscape(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
have string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ShouldNotAlterAlphaNum",
|
||||||
|
"abc123",
|
||||||
|
"abc123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ShouldEscapeSpecial",
|
||||||
|
"abc123><@#@",
|
||||||
|
"abc123><@#@",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.expected, StringHTMLEscape(tc.have))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStringSplitDelimitedEscaped(t *testing.T) {
|
func TestStringSplitDelimitedEscaped(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc, have string
|
desc, have string
|
||||||
|
|
|
@ -209,11 +209,51 @@ func TestParseTimeString(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
index, actual, err := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
index, actualA, errA := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||||
|
actualB, errB := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||||
|
actualC, errC := ParseTimeString(tc.have)
|
||||||
|
|
||||||
|
if tc.err == "" {
|
||||||
|
assert.NoError(t, errA)
|
||||||
|
assert.NoError(t, errB)
|
||||||
|
assert.NoError(t, errC)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.index, index)
|
||||||
|
assert.Equal(t, tc.expected.UnixNano(), actualA.UnixNano())
|
||||||
|
assert.Equal(t, tc.expected.UnixNano(), actualB.UnixNano())
|
||||||
|
assert.Equal(t, tc.expected.UnixNano(), actualC.UnixNano())
|
||||||
|
} else {
|
||||||
|
assert.EqualError(t, errA, tc.err)
|
||||||
|
assert.EqualError(t, errB, tc.err)
|
||||||
|
assert.EqualError(t, errC, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTimeStringWithLayouts(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
have string
|
||||||
|
index int
|
||||||
|
expected time.Time
|
||||||
|
err string
|
||||||
|
}{
|
||||||
|
{"ShouldParseIntegerAsUnix", "1675899060", -1, time.Unix(1675899060, 0), ""},
|
||||||
|
{"ShouldParseIntegerAsUnixMilli", "1675899060000", -2, time.Unix(1675899060, 0), ""},
|
||||||
|
{"ShouldParseIntegerAsUnixMicro", "1675899060000000", -3, time.Unix(1675899060, 0), ""},
|
||||||
|
{"ShouldNotParseSuperLargeInteger", "9999999999999999999999999999999999999999", -999, time.Unix(0, 0), "time value was detected as an integer but the integer could not be parsed: strconv.ParseInt: parsing \"9999999999999999999999999999999999999999\": value out of range"},
|
||||||
|
{"ShouldParseSimpleTime", "Jan 2 15:04:05 2006", 0, time.Unix(1136214245, 0), ""},
|
||||||
|
{"ShouldNotParseInvalidTime", "abc", -998, time.Unix(0, 0), "failed to find a suitable time layout for time 'abc'"},
|
||||||
|
{"ShouldMatchDate", "2020-05-01", 6, time.Unix(1588291200, 0), ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
actual, err := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||||
|
|
||||||
if tc.err == "" {
|
if tc.err == "" {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, tc.index, index)
|
|
||||||
assert.Equal(t, tc.expected.UnixNano(), actual.UnixNano())
|
assert.Equal(t, tc.expected.UnixNano(), actual.UnixNano())
|
||||||
} else {
|
} else {
|
||||||
assert.EqualError(t, err, tc.err)
|
assert.EqualError(t, err, tc.err)
|
||||||
|
|
|
@ -59,3 +59,8 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) {
|
||||||
assert.False(t, isURLSafe("https://secure.example.comc", "example.com"))
|
assert.False(t, isURLSafe("https://secure.example.comc", "example.com"))
|
||||||
assert.False(t, isURLSafe("https://secure.example.co", "example.com"))
|
assert.False(t, isURLSafe("https://secure.example.co", "example.com"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasDomainSuffix(t *testing.T) {
|
||||||
|
assert.False(t, HasDomainSuffix("abc", ""))
|
||||||
|
assert.False(t, HasDomainSuffix("", ""))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue