From 2e8a460a666d164bd8b9b6c646b44515540138bc Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 24 May 2023 22:33:05 +1000 Subject: [PATCH] test: add misc missing tests (#5479) Signed-off-by: James Elliott --- internal/metrics/prometheus_test.go | 20 +++ internal/random/cryptographical_test.go | 75 +++++++++++ internal/random/mathematical.go | 12 +- internal/random/mathematical_test.go | 75 +++++++++++ internal/regulation/regulator.go | 14 +- internal/regulation/regulator_test.go | 165 ++++++++++++++---------- internal/regulation/types.go | 2 +- internal/storage/const.go | 53 -------- internal/storage/types.go | 53 ++++++++ internal/storage/types_test.go | 87 +++++++++++++ internal/utils/aes_test.go | 6 +- internal/utils/io.go | 34 ----- internal/utils/strings_test.go | 91 +++++++++++++ internal/utils/time_test.go | 44 ++++++- internal/utils/url_test.go | 5 + 15 files changed, 559 insertions(+), 177 deletions(-) create mode 100644 internal/metrics/prometheus_test.go create mode 100644 internal/random/cryptographical_test.go create mode 100644 internal/random/mathematical_test.go create mode 100644 internal/storage/types_test.go delete mode 100644 internal/utils/io.go diff --git a/internal/metrics/prometheus_test.go b/internal/metrics/prometheus_test.go new file mode 100644 index 000000000..fd0220fbc --- /dev/null +++ b/internal/metrics/prometheus_test.go @@ -0,0 +1,20 @@ +package metrics + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewPrometheus(t *testing.T) { + p := NewPrometheus() + + assert.NotNil(t, p) + + p.RecordRequest("400", "GET", time.Second) + p.RecordAuthz("400") + p.RecordAuthn(true, false, "WebAuthn") + p.RecordAuthn(true, false, "1fa") + p.RecordAuthenticationDuration(true, time.Second) +} diff --git a/internal/random/cryptographical_test.go b/internal/random/cryptographical_test.go new file mode 100644 index 000000000..afb4ab29d --- /dev/null +++ b/internal/random/cryptographical_test.go @@ -0,0 +1,75 @@ +package random + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCryptographical(t *testing.T) { + p := &Cryptographical{} + + data := make([]byte, 10) + + n, err := p.Read(data) + assert.Equal(t, 10, n) + assert.NoError(t, err) + + data2, err := p.BytesErr() + assert.NoError(t, err) + assert.Len(t, data2, 72) + + data2 = p.Bytes() + assert.Len(t, data2, 72) + + data2 = p.BytesCustom(74, []byte(CharSetAlphabetic)) + assert.Len(t, data2, 74) + + data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic)) + assert.NoError(t, err) + assert.Len(t, data2, 76) + + data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic)) + assert.NoError(t, err) + assert.Len(t, data2, 72) + + strdata := p.StringCustom(10, CharSetAlphabetic) + assert.Len(t, strdata, 10) + + strdata, err = p.StringCustomErr(11, CharSetAlphabetic) + assert.NoError(t, err) + assert.Len(t, strdata, 11) + + i := p.Intn(999) + assert.Greater(t, i, 0) + assert.Less(t, i, 999) + + i, err = p.IntnErr(999) + assert.NoError(t, err) + assert.Greater(t, i, 0) + assert.Less(t, i, 999) + + i, err = p.IntnErr(-4) + assert.EqualError(t, err, "n must be more than 0") + assert.Equal(t, 0, i) + + bi := p.Int(big.NewInt(999)) + assert.Greater(t, bi.Int64(), int64(0)) + assert.Less(t, bi.Int64(), int64(999)) + + bi = p.Int(nil) + assert.Equal(t, int64(-1), bi.Int64()) + + bi, err = p.IntErr(nil) + assert.Nil(t, bi) + assert.EqualError(t, err, "max is required") + + bi, err = p.IntErr(big.NewInt(-1)) + assert.Nil(t, bi) + assert.EqualError(t, err, "max must be 1 or more") + + prime, err := p.Prime(64) + assert.NoError(t, err) + assert.NotNil(t, prime) +} diff --git a/internal/random/mathematical.go b/internal/random/mathematical.go index 24b8e34d3..a4498e96b 100644 --- a/internal/random/mathematical.go +++ b/internal/random/mathematical.go @@ -112,6 +112,10 @@ func (r *Mathematical) Intn(n int) int { // IntnErr returns a random int error combination with a maximum of n. func (r *Mathematical) IntnErr(n int) (output int, err error) { + if n <= 0 { + return 0, fmt.Errorf("n must be more than 0") + } + return r.Intn(n), nil } @@ -132,15 +136,11 @@ func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) { return nil, fmt.Errorf("max is required") } - if max.Sign() <= 0 { + if max.Int64() <= 0 { return nil, fmt.Errorf("max must be 1 or more") } - r.lock.Lock() - - defer r.lock.Unlock() - - return big.NewInt(int64(r.Intn(max.Sign()))), nil + return big.NewInt(int64(r.Intn(int(max.Int64())))), nil } // Prime returns a number of the given bit length that is prime with high probability. Prime will return error for any diff --git a/internal/random/mathematical_test.go b/internal/random/mathematical_test.go new file mode 100644 index 000000000..b2f296dfa --- /dev/null +++ b/internal/random/mathematical_test.go @@ -0,0 +1,75 @@ +package random + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMathematical(t *testing.T) { + p := NewMathematical() + + data := make([]byte, 10) + + n, err := p.Read(data) + assert.Equal(t, 10, n) + assert.NoError(t, err) + + data2, err := p.BytesErr() + assert.NoError(t, err) + assert.Len(t, data2, 72) + + data2 = p.Bytes() + assert.Len(t, data2, 72) + + data2 = p.BytesCustom(74, []byte(CharSetAlphabetic)) + assert.Len(t, data2, 74) + + data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic)) + assert.NoError(t, err) + assert.Len(t, data2, 76) + + data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic)) + assert.NoError(t, err) + assert.Len(t, data2, 72) + + strdata := p.StringCustom(10, CharSetAlphabetic) + assert.Len(t, strdata, 10) + + strdata, err = p.StringCustomErr(11, CharSetAlphabetic) + assert.NoError(t, err) + assert.Len(t, strdata, 11) + + i := p.Intn(999) + assert.Greater(t, i, 0) + assert.Less(t, i, 999) + + i, err = p.IntnErr(999) + assert.NoError(t, err) + assert.Greater(t, i, 0) + assert.Less(t, i, 999) + + i, err = p.IntnErr(-4) + assert.EqualError(t, err, "n must be more than 0") + assert.Equal(t, 0, i) + + bi := p.Int(big.NewInt(999)) + assert.Greater(t, bi.Int64(), int64(0)) + assert.Less(t, bi.Int64(), int64(999)) + + bi = p.Int(nil) + assert.Equal(t, int64(-1), bi.Int64()) + + bi, err = p.IntErr(nil) + assert.Nil(t, bi) + assert.EqualError(t, err, "max is required") + + bi, err = p.IntErr(big.NewInt(-1)) + assert.Nil(t, bi) + assert.EqualError(t, err, "max must be 1 or more") + + prime, err := p.Prime(64) + assert.NoError(t, err) + assert.NotNil(t, prime) +} diff --git a/internal/regulation/regulator.go b/internal/regulation/regulator.go index c57cabdcd..ab1995483 100644 --- a/internal/regulation/regulator.go +++ b/internal/regulation/regulator.go @@ -12,12 +12,12 @@ import ( ) // NewRegulator create a regulator instance. -func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator { +func NewRegulator(config schema.RegulationConfiguration, store storage.RegulatorProvider, clock utils.Clock) *Regulator { return &Regulator{ - enabled: config.MaxRetries > 0, - storageProvider: provider, - clock: clock, - config: config, + enabled: config.MaxRetries > 0, + store: store, + clock: clock, + config: config, } } @@ -26,7 +26,7 @@ func NewRegulator(config schema.RegulationConfiguration, provider storage.Regula func (r *Regulator) Mark(ctx Context, successful, banned bool, username, requestURI, requestMethod, authType string) error { ctx.RecordAuthn(successful, banned, strings.ToLower(authType)) - return r.storageProvider.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{ + return r.store.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{ Time: r.clock.Now(), Successful: successful, Banned: banned, @@ -46,7 +46,7 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e return time.Time{}, nil } - attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0) + attempts, err := r.store.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0) if err != nil { return time.Time{}, nil } diff --git a/internal/regulation/regulator_test.go b/internal/regulation/regulator_test.go index 6b61be4da..a7156dbd0 100644 --- a/internal/regulation/regulator_test.go +++ b/internal/regulation/regulator_test.go @@ -1,46 +1,69 @@ package regulation_test import ( - "context" + "fmt" + "net" "testing" "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/valyala/fasthttp" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/regulation" - "github.com/authelia/authelia/v4/internal/utils" ) type RegulatorSuite struct { suite.Suite - ctx context.Context - ctrl *gomock.Controller - storageMock *mocks.MockStorage - config schema.RegulationConfiguration - clock utils.TestingClock + mock *mocks.MockAutheliaCtx } func (s *RegulatorSuite) SetupTest() { - s.ctrl = gomock.NewController(s.T()) - s.storageMock = mocks.NewMockStorage(s.ctrl) - s.ctx = context.Background() - - s.config = schema.RegulationConfiguration{ + s.mock = mocks.NewMockAutheliaCtx(s.T()) + s.mock.Ctx.Configuration.Regulation = schema.RegulationConfiguration{ MaxRetries: 3, BanTime: time.Second * 180, FindTime: time.Second * 30, } - s.clock.Set(time.Now()) + + s.mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, "127.0.0.1") } func (s *RegulatorSuite) TearDownTest() { - s.ctrl.Finish() + s.mock.Ctrl.Finish() +} + +func (s *RegulatorSuite) TestShouldMark() { + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) + + s.mock.StorageMock.EXPECT().AppendAuthenticationLog(s.mock.Ctx, model.AuthenticationAttempt{ + Time: s.mock.Clock.Now(), + Successful: true, + Banned: false, + Username: "john", + Type: "1fa", + RemoteIP: model.NewNullIP(net.ParseIP("127.0.0.1")), + RequestURI: "https://google.com", + RequestMethod: fasthttp.MethodGet, + }) + + s.NoError(regulator.Mark(s.mock.Ctx, true, false, "john", "https://google.com", fasthttp.MethodGet, "1fa")) +} + +func (s *RegulatorSuite) TestShouldHandleRegulateError() { + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) + + s.mock.StorageMock.EXPECT().LoadAuthenticationLogs(s.mock.Ctx, "john", s.mock.Clock.Now().Add(-s.mock.Ctx.Configuration.Regulation.BanTime), 10, 0).Return(nil, fmt.Errorf("failed")) + + until, err := regulator.Regulate(s.mock.Ctx, "john") + + s.NoError(err) + s.Equal(time.Time{}, until) } func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() { @@ -48,17 +71,17 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() { { Username: "john", Successful: true, - Time: s.clock.Now().Add(-4 * time.Minute), + Time: s.mock.Clock.Now().Add(-4 * time.Minute), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.NoError(s.T(), err) } @@ -69,27 +92,27 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime { Username: "john", Successful: false, - Time: s.clock.Now().Add(-1 * time.Second), + Time: s.mock.Clock.Now().Add(-1 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-90 * time.Second), + Time: s.mock.Clock.Now().Add(-90 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-180 * time.Second), + Time: s.mock.Clock.Now().Add(-180 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.NoError(s.T(), err) } @@ -100,32 +123,32 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() { { Username: "john", Successful: false, - Time: s.clock.Now().Add(-1 * time.Second), + Time: s.mock.Clock.Now().Add(-1 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-4 * time.Second), + Time: s.mock.Clock.Now().Add(-4 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-6 * time.Second), + Time: s.mock.Clock.Now().Add(-6 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-180 * time.Second), + Time: s.mock.Clock.Now().Add(-180 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } @@ -138,27 +161,27 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() { { Username: "john", Successful: false, - Time: s.clock.Now().Add(-31 * time.Second), + Time: s.mock.Clock.Now().Add(-31 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-34 * time.Second), + Time: s.mock.Clock.Now().Add(-34 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-36 * time.Second), + Time: s.mock.Clock.Now().Add(-36 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } @@ -167,22 +190,22 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() { { Username: "john", Successful: false, - Time: s.clock.Now().Add(-34 * time.Second), + Time: s.mock.Clock.Now().Add(-34 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-36 * time.Second), + Time: s.mock.Clock.Now().Add(-36 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.NoError(s.T(), err) } @@ -191,7 +214,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() { { Username: "john", Successful: false, - Time: s.clock.Now().Add(-14 * time.Second), + Time: s.mock.Clock.Now().Add(-14 * time.Second), }, // more than 30 seconds elapsed between this auth and the preceding one. // In that case we don't need to regulate the user even though the number @@ -199,22 +222,22 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() { { Username: "john", Successful: false, - Time: s.clock.Now().Add(-94 * time.Second), + Time: s.mock.Clock.Now().Add(-94 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-96 * time.Second), + Time: s.mock.Clock.Now().Add(-96 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.NoError(s.T(), err) } @@ -223,34 +246,34 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp { Username: "john", Successful: false, - Time: s.clock.Now().Add(-90 * time.Second), + Time: s.mock.Clock.Now().Add(-90 * time.Second), }, { Username: "john", Successful: true, - Time: s.clock.Now().Add(-93 * time.Second), + Time: s.mock.Clock.Now().Add(-93 * time.Second), }, // The user was almost banned but he did a successful attempt. Therefore, even if the next // failure happens within FindTime, he should not be banned. { Username: "john", Successful: false, - Time: s.clock.Now().Add(-94 * time.Second), + Time: s.mock.Clock.Now().Add(-94 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-96 * time.Second), + Time: s.mock.Clock.Now().Add(-96 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.NoError(s.T(), err) } @@ -265,22 +288,22 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { { Username: "john", Successful: false, - Time: s.clock.Now().Add(-31 * time.Second), + Time: s.mock.Clock.Now().Add(-31 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-34 * time.Second), + Time: s.mock.Clock.Now().Add(-34 * time.Second), }, { Username: "john", Successful: false, - Time: s.clock.Now().Add(-36 * time.Second), + Time: s.mock.Clock.Now().Add(-36 * time.Second), }, } - s.storageMock.EXPECT(). - LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + s.mock.StorageMock.EXPECT(). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) // Check Disabled Functionality. @@ -290,8 +313,8 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { BanTime: time.Second * 180, } - regulator := regulation.NewRegulator(config, s.storageMock, &s.clock) - _, err := regulator.Regulate(s.ctx, "john") + regulator := regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock) + _, err := regulator.Regulate(s.mock.Ctx, "john") assert.NoError(s.T(), err) // Check Enabled Functionality. @@ -301,7 +324,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { BanTime: time.Second * 180, } - regulator = regulation.NewRegulator(config, s.storageMock, &s.clock) - _, err = regulator.Regulate(s.ctx, "john") + regulator = regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock) + _, err = regulator.Regulate(s.mock.Ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } diff --git a/internal/regulation/types.go b/internal/regulation/types.go index d5ad21edd..11c56fec0 100644 --- a/internal/regulation/types.go +++ b/internal/regulation/types.go @@ -16,7 +16,7 @@ type Regulator struct { config schema.RegulationConfiguration - storageProvider storage.RegulatorProvider + store storage.RegulatorProvider clock utils.Clock } diff --git a/internal/storage/const.go b/internal/storage/const.go index 2dec1b5e4..7d7b788fa 100644 --- a/internal/storage/const.go +++ b/internal/storage/const.go @@ -28,59 +28,6 @@ const ( tableEncryption = "encryption" ) -// OAuth2SessionType represents the potential OAuth 2.0 session types. -type OAuth2SessionType int - -// Representation of specific OAuth 2.0 session types. -const ( - OAuth2SessionTypeAccessToken OAuth2SessionType = iota - OAuth2SessionTypeAuthorizeCode - OAuth2SessionTypeOpenIDConnect - OAuth2SessionTypePAR - OAuth2SessionTypePKCEChallenge - OAuth2SessionTypeRefreshToken -) - -// String returns a string representation of this OAuth2SessionType. -func (s OAuth2SessionType) String() string { - switch s { - case OAuth2SessionTypeAccessToken: - return "access token" - case OAuth2SessionTypeAuthorizeCode: - return "authorization code" - case OAuth2SessionTypeOpenIDConnect: - return "openid connect" - case OAuth2SessionTypePAR: - return "pushed authorization request context" - case OAuth2SessionTypePKCEChallenge: - return "pkce challenge" - case OAuth2SessionTypeRefreshToken: - return "refresh token" - default: - return "invalid" - } -} - -// Table returns the table name for this session type. -func (s OAuth2SessionType) Table() string { - switch s { - case OAuth2SessionTypeAccessToken: - return tableOAuth2AccessTokenSession - case OAuth2SessionTypeAuthorizeCode: - return tableOAuth2AuthorizeCodeSession - case OAuth2SessionTypeOpenIDConnect: - return tableOAuth2OpenIDConnectSession - case OAuth2SessionTypePAR: - return tableOAuth2PARContext - case OAuth2SessionTypePKCEChallenge: - return tableOAuth2PKCERequestSession - case OAuth2SessionTypeRefreshToken: - return tableOAuth2RefreshTokenSession - default: - return "" - } -} - const ( encryptionNameCheck = "check" ) diff --git a/internal/storage/types.go b/internal/storage/types.go index 7537c7695..5954131e0 100644 --- a/internal/storage/types.go +++ b/internal/storage/types.go @@ -93,3 +93,56 @@ func (r EncryptionValidationTableResult) ResultDescriptor() string { return "SUCCESS" } + +// OAuth2SessionType represents the potential OAuth 2.0 session types. +type OAuth2SessionType int + +// Representation of specific OAuth 2.0 session types. +const ( + OAuth2SessionTypeAccessToken OAuth2SessionType = iota + OAuth2SessionTypeAuthorizeCode + OAuth2SessionTypeOpenIDConnect + OAuth2SessionTypePAR + OAuth2SessionTypePKCEChallenge + OAuth2SessionTypeRefreshToken +) + +// String returns a string representation of this OAuth2SessionType. +func (s OAuth2SessionType) String() string { + switch s { + case OAuth2SessionTypeAccessToken: + return "access token" + case OAuth2SessionTypeAuthorizeCode: + return "authorization code" + case OAuth2SessionTypeOpenIDConnect: + return "openid connect" + case OAuth2SessionTypePAR: + return "pushed authorization request context" + case OAuth2SessionTypePKCEChallenge: + return "pkce challenge" + case OAuth2SessionTypeRefreshToken: + return "refresh token" + default: + return "invalid" + } +} + +// Table returns the table name for this session type. +func (s OAuth2SessionType) Table() string { + switch s { + case OAuth2SessionTypeAccessToken: + return tableOAuth2AccessTokenSession + case OAuth2SessionTypeAuthorizeCode: + return tableOAuth2AuthorizeCodeSession + case OAuth2SessionTypeOpenIDConnect: + return tableOAuth2OpenIDConnectSession + case OAuth2SessionTypePAR: + return tableOAuth2PARContext + case OAuth2SessionTypePKCEChallenge: + return tableOAuth2PKCERequestSession + case OAuth2SessionTypeRefreshToken: + return tableOAuth2RefreshTokenSession + default: + return "" + } +} diff --git a/internal/storage/types_test.go b/internal/storage/types_test.go new file mode 100644 index 000000000..e4903ee26 --- /dev/null +++ b/internal/storage/types_test.go @@ -0,0 +1,87 @@ +package storage + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncryptionValidationResult(t *testing.T) { + result := &EncryptionValidationResult{ + InvalidCheckValue: false, + } + + assert.True(t, result.Success()) + assert.True(t, result.Checked()) + + result = &EncryptionValidationResult{ + InvalidCheckValue: true, + } + + assert.False(t, result.Success()) + assert.True(t, result.Checked()) + + result = &EncryptionValidationResult{ + InvalidCheckValue: false, + Tables: map[string]EncryptionValidationTableResult{ + tableWebAuthnDevices: { + Invalid: 10, + Total: 20, + }, + }, + } + assert.Equal(t, "FAILURE", result.Tables[tableWebAuthnDevices].ResultDescriptor()) + + assert.False(t, result.Success()) + assert.True(t, result.Checked()) + + result = &EncryptionValidationResult{ + InvalidCheckValue: false, + Tables: map[string]EncryptionValidationTableResult{ + tableWebAuthnDevices: { + Error: fmt.Errorf("failed to check table"), + }, + }, + } + + assert.False(t, result.Success()) + assert.False(t, result.Checked()) + assert.Equal(t, "N/A", result.Tables[tableWebAuthnDevices].ResultDescriptor()) + + result = &EncryptionValidationResult{ + InvalidCheckValue: false, + Tables: map[string]EncryptionValidationTableResult{ + tableWebAuthnDevices: { + Total: 20, + }, + }, + } + + assert.True(t, result.Success()) + assert.True(t, result.Checked()) + assert.Equal(t, "SUCCESS", result.Tables[tableWebAuthnDevices].ResultDescriptor()) +} + +func TestOAuth2SessionType(t *testing.T) { + assert.Equal(t, "access token", OAuth2SessionTypeAccessToken.String()) + assert.Equal(t, tableOAuth2AccessTokenSession, OAuth2SessionTypeAccessToken.Table()) + + assert.Equal(t, "authorization code", OAuth2SessionTypeAuthorizeCode.String()) + assert.Equal(t, tableOAuth2AuthorizeCodeSession, OAuth2SessionTypeAuthorizeCode.Table()) + + assert.Equal(t, "openid connect", OAuth2SessionTypeOpenIDConnect.String()) + assert.Equal(t, tableOAuth2OpenIDConnectSession, OAuth2SessionTypeOpenIDConnect.Table()) + + assert.Equal(t, "pushed authorization request context", OAuth2SessionTypePAR.String()) + assert.Equal(t, tableOAuth2PARContext, OAuth2SessionTypePAR.Table()) + + assert.Equal(t, "pkce challenge", OAuth2SessionTypePKCEChallenge.String()) + assert.Equal(t, tableOAuth2PKCERequestSession, OAuth2SessionTypePKCEChallenge.Table()) + + assert.Equal(t, "refresh token", OAuth2SessionTypeRefreshToken.String()) + assert.Equal(t, tableOAuth2RefreshTokenSession, OAuth2SessionTypeRefreshToken.Table()) + + assert.Equal(t, "invalid", OAuth2SessionType(-1).String()) + assert.Equal(t, "", OAuth2SessionType(-1).Table()) +} diff --git a/internal/utils/aes_test.go b/internal/utils/aes_test.go index 47482c664..b04234540 100644 --- a/internal/utils/aes_test.go +++ b/internal/utils/aes_test.go @@ -8,7 +8,7 @@ import ( ) func TestShouldEncryptAndDecriptUsingAES(t *testing.T) { - var key [32]byte = sha256.Sum256([]byte("the key")) + var key = sha256.Sum256([]byte("the key")) var secret = "the secret" @@ -22,7 +22,7 @@ func TestShouldEncryptAndDecriptUsingAES(t *testing.T) { } func TestShouldFailDecryptOnInvalidKey(t *testing.T) { - var key [32]byte = sha256.Sum256([]byte("the key")) + var key = sha256.Sum256([]byte("the key")) var secret = "the secret" @@ -37,7 +37,7 @@ func TestShouldFailDecryptOnInvalidKey(t *testing.T) { } func TestShouldFailDecryptOnInvalidCypherText(t *testing.T) { - var key [32]byte = sha256.Sum256([]byte("the key")) + var key = sha256.Sum256([]byte("the key")) encryptedSecret := []byte("abc123") diff --git a/internal/utils/io.go b/internal/utils/io.go deleted file mode 100644 index eacaf0386..000000000 --- a/internal/utils/io.go +++ /dev/null @@ -1,34 +0,0 @@ -package utils - -import ( - "errors" - "io" -) - -// NewWriteCloser creates a new io.WriteCloser from an io.Writer. -func NewWriteCloser(wr io.Writer) io.WriteCloser { - return &WriteCloser{wr: wr} -} - -// WriteCloser is a io.Writer with an io.Closer. -type WriteCloser struct { - wr io.Writer - - closed bool -} - -// Write to the io.Writer. -func (w *WriteCloser) Write(p []byte) (n int, err error) { - if w.closed { - return -1, errors.New("already closed") - } - - return w.wr.Write(p) -} - -// Close the io.Closer. -func (w *WriteCloser) Close() error { - w.closed = true - - return nil -} diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index c12978739..3a9147111 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -8,6 +8,97 @@ import ( "github.com/stretchr/testify/require" ) +func TestIsStringAbsURL(t *testing.T) { + testCases := []struct { + name string + have string + err string + }{ + { + "ShouldBeAbs", + "https://google.com", + "", + }, + { + "ShouldNotBeAbs", + "google.com", + "could not parse 'google.com' as a URL", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + theError := IsStringAbsURL(tc.have) + + if tc.err == "" { + assert.NoError(t, theError) + } else { + assert.EqualError(t, theError, tc.err) + } + }) + } +} + +func TestIsStringInSliceF(t *testing.T) { + testCases := []struct { + name string + needle string + haystack []string + isEqual func(needle, item string) bool + expected bool + }{ + { + "ShouldBePresent", + "good", + []string{"good"}, + func(needle, item string) bool { + return needle == item + }, + true, + }, + { + "ShouldNotBePresent", + "bad", + []string{"good"}, + func(needle, item string) bool { + return needle == item + }, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsStringInSliceF(tc.needle, tc.haystack, tc.isEqual)) + }) + } +} + +func TestStringHTMLEscape(t *testing.T) { + testCases := []struct { + name string + have string + expected string + }{ + { + "ShouldNotAlterAlphaNum", + "abc123", + "abc123", + }, + { + "ShouldEscapeSpecial", + "abc123><@#@", + "abc123><@#@", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, StringHTMLEscape(tc.have)) + }) + } +} + func TestStringSplitDelimitedEscaped(t *testing.T) { testCases := []struct { desc, have string diff --git a/internal/utils/time_test.go b/internal/utils/time_test.go index d1e1b6681..c98ef73c8 100644 --- a/internal/utils/time_test.go +++ b/internal/utils/time_test.go @@ -209,11 +209,51 @@ func TestParseTimeString(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - index, actual, err := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts) + index, actualA, errA := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts) + actualB, errB := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts) + actualC, errC := ParseTimeString(tc.have) + + if tc.err == "" { + assert.NoError(t, errA) + assert.NoError(t, errB) + assert.NoError(t, errC) + + assert.Equal(t, tc.index, index) + assert.Equal(t, tc.expected.UnixNano(), actualA.UnixNano()) + assert.Equal(t, tc.expected.UnixNano(), actualB.UnixNano()) + assert.Equal(t, tc.expected.UnixNano(), actualC.UnixNano()) + } else { + assert.EqualError(t, errA, tc.err) + assert.EqualError(t, errB, tc.err) + assert.EqualError(t, errC, tc.err) + } + }) + } +} + +func TestParseTimeStringWithLayouts(t *testing.T) { + testCases := []struct { + name string + have string + index int + expected time.Time + err string + }{ + {"ShouldParseIntegerAsUnix", "1675899060", -1, time.Unix(1675899060, 0), ""}, + {"ShouldParseIntegerAsUnixMilli", "1675899060000", -2, time.Unix(1675899060, 0), ""}, + {"ShouldParseIntegerAsUnixMicro", "1675899060000000", -3, time.Unix(1675899060, 0), ""}, + {"ShouldNotParseSuperLargeInteger", "9999999999999999999999999999999999999999", -999, time.Unix(0, 0), "time value was detected as an integer but the integer could not be parsed: strconv.ParseInt: parsing \"9999999999999999999999999999999999999999\": value out of range"}, + {"ShouldParseSimpleTime", "Jan 2 15:04:05 2006", 0, time.Unix(1136214245, 0), ""}, + {"ShouldNotParseInvalidTime", "abc", -998, time.Unix(0, 0), "failed to find a suitable time layout for time 'abc'"}, + {"ShouldMatchDate", "2020-05-01", 6, time.Unix(1588291200, 0), ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual, err := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts) if tc.err == "" { assert.NoError(t, err) - assert.Equal(t, tc.index, index) assert.Equal(t, tc.expected.UnixNano(), actual.UnixNano()) } else { assert.EqualError(t, err, tc.err) diff --git a/internal/utils/url_test.go b/internal/utils/url_test.go index 03b51540f..dc8ab1508 100644 --- a/internal/utils/url_test.go +++ b/internal/utils/url_test.go @@ -59,3 +59,8 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) { assert.False(t, isURLSafe("https://secure.example.comc", "example.com")) assert.False(t, isURLSafe("https://secure.example.co", "example.com")) } + +func TestHasDomainSuffix(t *testing.T) { + assert.False(t, HasDomainSuffix("abc", "")) + assert.False(t, HasDomainSuffix("", "")) +}