test: add misc missing tests (#5479)
Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>pull/5481/head
parent
e784a72735
commit
2e8a460a66
|
@ -0,0 +1,20 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewPrometheus(t *testing.T) {
|
||||
p := NewPrometheus()
|
||||
|
||||
assert.NotNil(t, p)
|
||||
|
||||
p.RecordRequest("400", "GET", time.Second)
|
||||
p.RecordAuthz("400")
|
||||
p.RecordAuthn(true, false, "WebAuthn")
|
||||
p.RecordAuthn(true, false, "1fa")
|
||||
p.RecordAuthenticationDuration(true, time.Second)
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCryptographical(t *testing.T) {
|
||||
p := &Cryptographical{}
|
||||
|
||||
data := make([]byte, 10)
|
||||
|
||||
n, err := p.Read(data)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data2, err := p.BytesErr()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, data2, 72)
|
||||
|
||||
data2 = p.Bytes()
|
||||
assert.Len(t, data2, 72)
|
||||
|
||||
data2 = p.BytesCustom(74, []byte(CharSetAlphabetic))
|
||||
assert.Len(t, data2, 74)
|
||||
|
||||
data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic))
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, data2, 76)
|
||||
|
||||
data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic))
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, data2, 72)
|
||||
|
||||
strdata := p.StringCustom(10, CharSetAlphabetic)
|
||||
assert.Len(t, strdata, 10)
|
||||
|
||||
strdata, err = p.StringCustomErr(11, CharSetAlphabetic)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, strdata, 11)
|
||||
|
||||
i := p.Intn(999)
|
||||
assert.Greater(t, i, 0)
|
||||
assert.Less(t, i, 999)
|
||||
|
||||
i, err = p.IntnErr(999)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, i, 0)
|
||||
assert.Less(t, i, 999)
|
||||
|
||||
i, err = p.IntnErr(-4)
|
||||
assert.EqualError(t, err, "n must be more than 0")
|
||||
assert.Equal(t, 0, i)
|
||||
|
||||
bi := p.Int(big.NewInt(999))
|
||||
assert.Greater(t, bi.Int64(), int64(0))
|
||||
assert.Less(t, bi.Int64(), int64(999))
|
||||
|
||||
bi = p.Int(nil)
|
||||
assert.Equal(t, int64(-1), bi.Int64())
|
||||
|
||||
bi, err = p.IntErr(nil)
|
||||
assert.Nil(t, bi)
|
||||
assert.EqualError(t, err, "max is required")
|
||||
|
||||
bi, err = p.IntErr(big.NewInt(-1))
|
||||
assert.Nil(t, bi)
|
||||
assert.EqualError(t, err, "max must be 1 or more")
|
||||
|
||||
prime, err := p.Prime(64)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, prime)
|
||||
}
|
|
@ -112,6 +112,10 @@ func (r *Mathematical) Intn(n int) int {
|
|||
|
||||
// IntnErr returns a random int error combination with a maximum of n.
|
||||
func (r *Mathematical) IntnErr(n int) (output int, err error) {
|
||||
if n <= 0 {
|
||||
return 0, fmt.Errorf("n must be more than 0")
|
||||
}
|
||||
|
||||
return r.Intn(n), nil
|
||||
}
|
||||
|
||||
|
@ -132,15 +136,11 @@ func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) {
|
|||
return nil, fmt.Errorf("max is required")
|
||||
}
|
||||
|
||||
if max.Sign() <= 0 {
|
||||
if max.Int64() <= 0 {
|
||||
return nil, fmt.Errorf("max must be 1 or more")
|
||||
}
|
||||
|
||||
r.lock.Lock()
|
||||
|
||||
defer r.lock.Unlock()
|
||||
|
||||
return big.NewInt(int64(r.Intn(max.Sign()))), nil
|
||||
return big.NewInt(int64(r.Intn(int(max.Int64())))), nil
|
||||
}
|
||||
|
||||
// Prime returns a number of the given bit length that is prime with high probability. Prime will return error for any
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMathematical(t *testing.T) {
|
||||
p := NewMathematical()
|
||||
|
||||
data := make([]byte, 10)
|
||||
|
||||
n, err := p.Read(data)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data2, err := p.BytesErr()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, data2, 72)
|
||||
|
||||
data2 = p.Bytes()
|
||||
assert.Len(t, data2, 72)
|
||||
|
||||
data2 = p.BytesCustom(74, []byte(CharSetAlphabetic))
|
||||
assert.Len(t, data2, 74)
|
||||
|
||||
data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic))
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, data2, 76)
|
||||
|
||||
data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic))
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, data2, 72)
|
||||
|
||||
strdata := p.StringCustom(10, CharSetAlphabetic)
|
||||
assert.Len(t, strdata, 10)
|
||||
|
||||
strdata, err = p.StringCustomErr(11, CharSetAlphabetic)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, strdata, 11)
|
||||
|
||||
i := p.Intn(999)
|
||||
assert.Greater(t, i, 0)
|
||||
assert.Less(t, i, 999)
|
||||
|
||||
i, err = p.IntnErr(999)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, i, 0)
|
||||
assert.Less(t, i, 999)
|
||||
|
||||
i, err = p.IntnErr(-4)
|
||||
assert.EqualError(t, err, "n must be more than 0")
|
||||
assert.Equal(t, 0, i)
|
||||
|
||||
bi := p.Int(big.NewInt(999))
|
||||
assert.Greater(t, bi.Int64(), int64(0))
|
||||
assert.Less(t, bi.Int64(), int64(999))
|
||||
|
||||
bi = p.Int(nil)
|
||||
assert.Equal(t, int64(-1), bi.Int64())
|
||||
|
||||
bi, err = p.IntErr(nil)
|
||||
assert.Nil(t, bi)
|
||||
assert.EqualError(t, err, "max is required")
|
||||
|
||||
bi, err = p.IntErr(big.NewInt(-1))
|
||||
assert.Nil(t, bi)
|
||||
assert.EqualError(t, err, "max must be 1 or more")
|
||||
|
||||
prime, err := p.Prime(64)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, prime)
|
||||
}
|
|
@ -12,12 +12,12 @@ import (
|
|||
)
|
||||
|
||||
// NewRegulator create a regulator instance.
|
||||
func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
||||
func NewRegulator(config schema.RegulationConfiguration, store storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
||||
return &Regulator{
|
||||
enabled: config.MaxRetries > 0,
|
||||
storageProvider: provider,
|
||||
clock: clock,
|
||||
config: config,
|
||||
enabled: config.MaxRetries > 0,
|
||||
store: store,
|
||||
clock: clock,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,7 +26,7 @@ func NewRegulator(config schema.RegulationConfiguration, provider storage.Regula
|
|||
func (r *Regulator) Mark(ctx Context, successful, banned bool, username, requestURI, requestMethod, authType string) error {
|
||||
ctx.RecordAuthn(successful, banned, strings.ToLower(authType))
|
||||
|
||||
return r.storageProvider.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{
|
||||
return r.store.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{
|
||||
Time: r.clock.Now(),
|
||||
Successful: successful,
|
||||
Banned: banned,
|
||||
|
@ -46,7 +46,7 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
|
|||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
|
||||
attempts, err := r.store.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
|
||||
if err != nil {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
|
|
@ -1,46 +1,69 @@
|
|||
package regulation_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/mocks"
|
||||
"github.com/authelia/authelia/v4/internal/model"
|
||||
"github.com/authelia/authelia/v4/internal/regulation"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
type RegulatorSuite struct {
|
||||
suite.Suite
|
||||
|
||||
ctx context.Context
|
||||
ctrl *gomock.Controller
|
||||
storageMock *mocks.MockStorage
|
||||
config schema.RegulationConfiguration
|
||||
clock utils.TestingClock
|
||||
mock *mocks.MockAutheliaCtx
|
||||
}
|
||||
|
||||
func (s *RegulatorSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.storageMock = mocks.NewMockStorage(s.ctrl)
|
||||
s.ctx = context.Background()
|
||||
|
||||
s.config = schema.RegulationConfiguration{
|
||||
s.mock = mocks.NewMockAutheliaCtx(s.T())
|
||||
s.mock.Ctx.Configuration.Regulation = schema.RegulationConfiguration{
|
||||
MaxRetries: 3,
|
||||
BanTime: time.Second * 180,
|
||||
FindTime: time.Second * 30,
|
||||
}
|
||||
s.clock.Set(time.Now())
|
||||
|
||||
s.mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, "127.0.0.1")
|
||||
}
|
||||
|
||||
func (s *RegulatorSuite) TearDownTest() {
|
||||
s.ctrl.Finish()
|
||||
s.mock.Ctrl.Finish()
|
||||
}
|
||||
|
||||
func (s *RegulatorSuite) TestShouldMark() {
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
s.mock.StorageMock.EXPECT().AppendAuthenticationLog(s.mock.Ctx, model.AuthenticationAttempt{
|
||||
Time: s.mock.Clock.Now(),
|
||||
Successful: true,
|
||||
Banned: false,
|
||||
Username: "john",
|
||||
Type: "1fa",
|
||||
RemoteIP: model.NewNullIP(net.ParseIP("127.0.0.1")),
|
||||
RequestURI: "https://google.com",
|
||||
RequestMethod: fasthttp.MethodGet,
|
||||
})
|
||||
|
||||
s.NoError(regulator.Mark(s.mock.Ctx, true, false, "john", "https://google.com", fasthttp.MethodGet, "1fa"))
|
||||
}
|
||||
|
||||
func (s *RegulatorSuite) TestShouldHandleRegulateError() {
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
s.mock.StorageMock.EXPECT().LoadAuthenticationLogs(s.mock.Ctx, "john", s.mock.Clock.Now().Add(-s.mock.Ctx.Configuration.Regulation.BanTime), 10, 0).Return(nil, fmt.Errorf("failed"))
|
||||
|
||||
until, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(time.Time{}, until)
|
||||
}
|
||||
|
||||
func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
||||
|
@ -48,17 +71,17 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: true,
|
||||
Time: s.clock.Now().Add(-4 * time.Minute),
|
||||
Time: s.mock.Clock.Now().Add(-4 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -69,27 +92,27 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-1 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-1 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-90 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-90 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-180 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-180 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -100,32 +123,32 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-1 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-1 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-4 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-4 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-6 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-6 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-180 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-180 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||
}
|
||||
|
||||
|
@ -138,27 +161,27 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-31 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-31 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-34 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-34 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-36 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-36 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||
}
|
||||
|
||||
|
@ -167,22 +190,22 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-34 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-34 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-36 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-36 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -191,7 +214,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-14 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-14 * time.Second),
|
||||
},
|
||||
// more than 30 seconds elapsed between this auth and the preceding one.
|
||||
// In that case we don't need to regulate the user even though the number
|
||||
|
@ -199,22 +222,22 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-94 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-94 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-96 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-96 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -223,34 +246,34 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-90 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-90 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: true,
|
||||
Time: s.clock.Now().Add(-93 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-93 * time.Second),
|
||||
},
|
||||
// The user was almost banned but he did a successful attempt. Therefore, even if the next
|
||||
// failure happens within FindTime, he should not be banned.
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-94 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-94 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-96 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-96 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
|
||||
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -265,22 +288,22 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
|||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-31 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-31 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-34 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-34 * time.Second),
|
||||
},
|
||||
{
|
||||
Username: "john",
|
||||
Successful: false,
|
||||
Time: s.clock.Now().Add(-36 * time.Second),
|
||||
Time: s.mock.Clock.Now().Add(-36 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
s.mock.StorageMock.EXPECT().
|
||||
LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
// Check Disabled Functionality.
|
||||
|
@ -290,8 +313,8 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
|||
BanTime: time.Second * 180,
|
||||
}
|
||||
|
||||
regulator := regulation.NewRegulator(config, s.storageMock, &s.clock)
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
regulator := regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock)
|
||||
_, err := regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
|
||||
// Check Enabled Functionality.
|
||||
|
@ -301,7 +324,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
|||
BanTime: time.Second * 180,
|
||||
}
|
||||
|
||||
regulator = regulation.NewRegulator(config, s.storageMock, &s.clock)
|
||||
_, err = regulator.Regulate(s.ctx, "john")
|
||||
regulator = regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock)
|
||||
_, err = regulator.Regulate(s.mock.Ctx, "john")
|
||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ type Regulator struct {
|
|||
|
||||
config schema.RegulationConfiguration
|
||||
|
||||
storageProvider storage.RegulatorProvider
|
||||
store storage.RegulatorProvider
|
||||
|
||||
clock utils.Clock
|
||||
}
|
||||
|
|
|
@ -28,59 +28,6 @@ const (
|
|||
tableEncryption = "encryption"
|
||||
)
|
||||
|
||||
// OAuth2SessionType represents the potential OAuth 2.0 session types.
|
||||
type OAuth2SessionType int
|
||||
|
||||
// Representation of specific OAuth 2.0 session types.
|
||||
const (
|
||||
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
|
||||
OAuth2SessionTypeAuthorizeCode
|
||||
OAuth2SessionTypeOpenIDConnect
|
||||
OAuth2SessionTypePAR
|
||||
OAuth2SessionTypePKCEChallenge
|
||||
OAuth2SessionTypeRefreshToken
|
||||
)
|
||||
|
||||
// String returns a string representation of this OAuth2SessionType.
|
||||
func (s OAuth2SessionType) String() string {
|
||||
switch s {
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
return "access token"
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return "authorization code"
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
return "openid connect"
|
||||
case OAuth2SessionTypePAR:
|
||||
return "pushed authorization request context"
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return "pkce challenge"
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return "refresh token"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
// Table returns the table name for this session type.
|
||||
func (s OAuth2SessionType) Table() string {
|
||||
switch s {
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
return tableOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return tableOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
return tableOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePAR:
|
||||
return tableOAuth2PARContext
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return tableOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return tableOAuth2RefreshTokenSession
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
encryptionNameCheck = "check"
|
||||
)
|
||||
|
|
|
@ -93,3 +93,56 @@ func (r EncryptionValidationTableResult) ResultDescriptor() string {
|
|||
|
||||
return "SUCCESS"
|
||||
}
|
||||
|
||||
// OAuth2SessionType represents the potential OAuth 2.0 session types.
|
||||
type OAuth2SessionType int
|
||||
|
||||
// Representation of specific OAuth 2.0 session types.
|
||||
const (
|
||||
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
|
||||
OAuth2SessionTypeAuthorizeCode
|
||||
OAuth2SessionTypeOpenIDConnect
|
||||
OAuth2SessionTypePAR
|
||||
OAuth2SessionTypePKCEChallenge
|
||||
OAuth2SessionTypeRefreshToken
|
||||
)
|
||||
|
||||
// String returns a string representation of this OAuth2SessionType.
|
||||
func (s OAuth2SessionType) String() string {
|
||||
switch s {
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
return "access token"
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return "authorization code"
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
return "openid connect"
|
||||
case OAuth2SessionTypePAR:
|
||||
return "pushed authorization request context"
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return "pkce challenge"
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return "refresh token"
|
||||
default:
|
||||
return "invalid"
|
||||
}
|
||||
}
|
||||
|
||||
// Table returns the table name for this session type.
|
||||
func (s OAuth2SessionType) Table() string {
|
||||
switch s {
|
||||
case OAuth2SessionTypeAccessToken:
|
||||
return tableOAuth2AccessTokenSession
|
||||
case OAuth2SessionTypeAuthorizeCode:
|
||||
return tableOAuth2AuthorizeCodeSession
|
||||
case OAuth2SessionTypeOpenIDConnect:
|
||||
return tableOAuth2OpenIDConnectSession
|
||||
case OAuth2SessionTypePAR:
|
||||
return tableOAuth2PARContext
|
||||
case OAuth2SessionTypePKCEChallenge:
|
||||
return tableOAuth2PKCERequestSession
|
||||
case OAuth2SessionTypeRefreshToken:
|
||||
return tableOAuth2RefreshTokenSession
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEncryptionValidationResult(t *testing.T) {
|
||||
result := &EncryptionValidationResult{
|
||||
InvalidCheckValue: false,
|
||||
}
|
||||
|
||||
assert.True(t, result.Success())
|
||||
assert.True(t, result.Checked())
|
||||
|
||||
result = &EncryptionValidationResult{
|
||||
InvalidCheckValue: true,
|
||||
}
|
||||
|
||||
assert.False(t, result.Success())
|
||||
assert.True(t, result.Checked())
|
||||
|
||||
result = &EncryptionValidationResult{
|
||||
InvalidCheckValue: false,
|
||||
Tables: map[string]EncryptionValidationTableResult{
|
||||
tableWebAuthnDevices: {
|
||||
Invalid: 10,
|
||||
Total: 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, "FAILURE", result.Tables[tableWebAuthnDevices].ResultDescriptor())
|
||||
|
||||
assert.False(t, result.Success())
|
||||
assert.True(t, result.Checked())
|
||||
|
||||
result = &EncryptionValidationResult{
|
||||
InvalidCheckValue: false,
|
||||
Tables: map[string]EncryptionValidationTableResult{
|
||||
tableWebAuthnDevices: {
|
||||
Error: fmt.Errorf("failed to check table"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.False(t, result.Success())
|
||||
assert.False(t, result.Checked())
|
||||
assert.Equal(t, "N/A", result.Tables[tableWebAuthnDevices].ResultDescriptor())
|
||||
|
||||
result = &EncryptionValidationResult{
|
||||
InvalidCheckValue: false,
|
||||
Tables: map[string]EncryptionValidationTableResult{
|
||||
tableWebAuthnDevices: {
|
||||
Total: 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, result.Success())
|
||||
assert.True(t, result.Checked())
|
||||
assert.Equal(t, "SUCCESS", result.Tables[tableWebAuthnDevices].ResultDescriptor())
|
||||
}
|
||||
|
||||
func TestOAuth2SessionType(t *testing.T) {
|
||||
assert.Equal(t, "access token", OAuth2SessionTypeAccessToken.String())
|
||||
assert.Equal(t, tableOAuth2AccessTokenSession, OAuth2SessionTypeAccessToken.Table())
|
||||
|
||||
assert.Equal(t, "authorization code", OAuth2SessionTypeAuthorizeCode.String())
|
||||
assert.Equal(t, tableOAuth2AuthorizeCodeSession, OAuth2SessionTypeAuthorizeCode.Table())
|
||||
|
||||
assert.Equal(t, "openid connect", OAuth2SessionTypeOpenIDConnect.String())
|
||||
assert.Equal(t, tableOAuth2OpenIDConnectSession, OAuth2SessionTypeOpenIDConnect.Table())
|
||||
|
||||
assert.Equal(t, "pushed authorization request context", OAuth2SessionTypePAR.String())
|
||||
assert.Equal(t, tableOAuth2PARContext, OAuth2SessionTypePAR.Table())
|
||||
|
||||
assert.Equal(t, "pkce challenge", OAuth2SessionTypePKCEChallenge.String())
|
||||
assert.Equal(t, tableOAuth2PKCERequestSession, OAuth2SessionTypePKCEChallenge.Table())
|
||||
|
||||
assert.Equal(t, "refresh token", OAuth2SessionTypeRefreshToken.String())
|
||||
assert.Equal(t, tableOAuth2RefreshTokenSession, OAuth2SessionTypeRefreshToken.Table())
|
||||
|
||||
assert.Equal(t, "invalid", OAuth2SessionType(-1).String())
|
||||
assert.Equal(t, "", OAuth2SessionType(-1).Table())
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
|
||||
var key [32]byte = sha256.Sum256([]byte("the key"))
|
||||
var key = sha256.Sum256([]byte("the key"))
|
||||
|
||||
var secret = "the secret"
|
||||
|
||||
|
@ -22,7 +22,7 @@ func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
|
||||
var key [32]byte = sha256.Sum256([]byte("the key"))
|
||||
var key = sha256.Sum256([]byte("the key"))
|
||||
|
||||
var secret = "the secret"
|
||||
|
||||
|
@ -37,7 +37,7 @@ func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldFailDecryptOnInvalidCypherText(t *testing.T) {
|
||||
var key [32]byte = sha256.Sum256([]byte("the key"))
|
||||
var key = sha256.Sum256([]byte("the key"))
|
||||
|
||||
encryptedSecret := []byte("abc123")
|
||||
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// NewWriteCloser creates a new io.WriteCloser from an io.Writer.
|
||||
func NewWriteCloser(wr io.Writer) io.WriteCloser {
|
||||
return &WriteCloser{wr: wr}
|
||||
}
|
||||
|
||||
// WriteCloser is a io.Writer with an io.Closer.
|
||||
type WriteCloser struct {
|
||||
wr io.Writer
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Write to the io.Writer.
|
||||
func (w *WriteCloser) Write(p []byte) (n int, err error) {
|
||||
if w.closed {
|
||||
return -1, errors.New("already closed")
|
||||
}
|
||||
|
||||
return w.wr.Write(p)
|
||||
}
|
||||
|
||||
// Close the io.Closer.
|
||||
func (w *WriteCloser) Close() error {
|
||||
w.closed = true
|
||||
|
||||
return nil
|
||||
}
|
|
@ -8,6 +8,97 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsStringAbsURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
have string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"ShouldBeAbs",
|
||||
"https://google.com",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"ShouldNotBeAbs",
|
||||
"google.com",
|
||||
"could not parse 'google.com' as a URL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
theError := IsStringAbsURL(tc.have)
|
||||
|
||||
if tc.err == "" {
|
||||
assert.NoError(t, theError)
|
||||
} else {
|
||||
assert.EqualError(t, theError, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsStringInSliceF(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
needle string
|
||||
haystack []string
|
||||
isEqual func(needle, item string) bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"ShouldBePresent",
|
||||
"good",
|
||||
[]string{"good"},
|
||||
func(needle, item string) bool {
|
||||
return needle == item
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"ShouldNotBePresent",
|
||||
"bad",
|
||||
[]string{"good"},
|
||||
func(needle, item string) bool {
|
||||
return needle == item
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, IsStringInSliceF(tc.needle, tc.haystack, tc.isEqual))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringHTMLEscape(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
have string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"ShouldNotAlterAlphaNum",
|
||||
"abc123",
|
||||
"abc123",
|
||||
},
|
||||
{
|
||||
"ShouldEscapeSpecial",
|
||||
"abc123><@#@",
|
||||
"abc123><@#@",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, StringHTMLEscape(tc.have))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSplitDelimitedEscaped(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc, have string
|
||||
|
|
|
@ -209,11 +209,51 @@ func TestParseTimeString(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
index, actual, err := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||
index, actualA, errA := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||
actualB, errB := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||
actualC, errC := ParseTimeString(tc.have)
|
||||
|
||||
if tc.err == "" {
|
||||
assert.NoError(t, errA)
|
||||
assert.NoError(t, errB)
|
||||
assert.NoError(t, errC)
|
||||
|
||||
assert.Equal(t, tc.index, index)
|
||||
assert.Equal(t, tc.expected.UnixNano(), actualA.UnixNano())
|
||||
assert.Equal(t, tc.expected.UnixNano(), actualB.UnixNano())
|
||||
assert.Equal(t, tc.expected.UnixNano(), actualC.UnixNano())
|
||||
} else {
|
||||
assert.EqualError(t, errA, tc.err)
|
||||
assert.EqualError(t, errB, tc.err)
|
||||
assert.EqualError(t, errC, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTimeStringWithLayouts(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
have string
|
||||
index int
|
||||
expected time.Time
|
||||
err string
|
||||
}{
|
||||
{"ShouldParseIntegerAsUnix", "1675899060", -1, time.Unix(1675899060, 0), ""},
|
||||
{"ShouldParseIntegerAsUnixMilli", "1675899060000", -2, time.Unix(1675899060, 0), ""},
|
||||
{"ShouldParseIntegerAsUnixMicro", "1675899060000000", -3, time.Unix(1675899060, 0), ""},
|
||||
{"ShouldNotParseSuperLargeInteger", "9999999999999999999999999999999999999999", -999, time.Unix(0, 0), "time value was detected as an integer but the integer could not be parsed: strconv.ParseInt: parsing \"9999999999999999999999999999999999999999\": value out of range"},
|
||||
{"ShouldParseSimpleTime", "Jan 2 15:04:05 2006", 0, time.Unix(1136214245, 0), ""},
|
||||
{"ShouldNotParseInvalidTime", "abc", -998, time.Unix(0, 0), "failed to find a suitable time layout for time 'abc'"},
|
||||
{"ShouldMatchDate", "2020-05-01", 6, time.Unix(1588291200, 0), ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual, err := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
|
||||
|
||||
if tc.err == "" {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.index, index)
|
||||
assert.Equal(t, tc.expected.UnixNano(), actual.UnixNano())
|
||||
} else {
|
||||
assert.EqualError(t, err, tc.err)
|
||||
|
|
|
@ -59,3 +59,8 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) {
|
|||
assert.False(t, isURLSafe("https://secure.example.comc", "example.com"))
|
||||
assert.False(t, isURLSafe("https://secure.example.co", "example.com"))
|
||||
}
|
||||
|
||||
func TestHasDomainSuffix(t *testing.T) {
|
||||
assert.False(t, HasDomainSuffix("abc", ""))
|
||||
assert.False(t, HasDomainSuffix("", ""))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue