test: add misc missing tests (#5479)

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
pull/5481/head
James Elliott 2023-05-24 22:33:05 +10:00 committed by GitHub
parent e784a72735
commit 2e8a460a66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 559 additions and 177 deletions

View File

@ -0,0 +1,20 @@
package metrics
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewPrometheus(t *testing.T) {
p := NewPrometheus()
assert.NotNil(t, p)
p.RecordRequest("400", "GET", time.Second)
p.RecordAuthz("400")
p.RecordAuthn(true, false, "WebAuthn")
p.RecordAuthn(true, false, "1fa")
p.RecordAuthenticationDuration(true, time.Second)
}

View File

@ -0,0 +1,75 @@
package random
import (
"math/big"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCryptographical(t *testing.T) {
p := &Cryptographical{}
data := make([]byte, 10)
n, err := p.Read(data)
assert.Equal(t, 10, n)
assert.NoError(t, err)
data2, err := p.BytesErr()
assert.NoError(t, err)
assert.Len(t, data2, 72)
data2 = p.Bytes()
assert.Len(t, data2, 72)
data2 = p.BytesCustom(74, []byte(CharSetAlphabetic))
assert.Len(t, data2, 74)
data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic))
assert.NoError(t, err)
assert.Len(t, data2, 76)
data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic))
assert.NoError(t, err)
assert.Len(t, data2, 72)
strdata := p.StringCustom(10, CharSetAlphabetic)
assert.Len(t, strdata, 10)
strdata, err = p.StringCustomErr(11, CharSetAlphabetic)
assert.NoError(t, err)
assert.Len(t, strdata, 11)
i := p.Intn(999)
assert.Greater(t, i, 0)
assert.Less(t, i, 999)
i, err = p.IntnErr(999)
assert.NoError(t, err)
assert.Greater(t, i, 0)
assert.Less(t, i, 999)
i, err = p.IntnErr(-4)
assert.EqualError(t, err, "n must be more than 0")
assert.Equal(t, 0, i)
bi := p.Int(big.NewInt(999))
assert.Greater(t, bi.Int64(), int64(0))
assert.Less(t, bi.Int64(), int64(999))
bi = p.Int(nil)
assert.Equal(t, int64(-1), bi.Int64())
bi, err = p.IntErr(nil)
assert.Nil(t, bi)
assert.EqualError(t, err, "max is required")
bi, err = p.IntErr(big.NewInt(-1))
assert.Nil(t, bi)
assert.EqualError(t, err, "max must be 1 or more")
prime, err := p.Prime(64)
assert.NoError(t, err)
assert.NotNil(t, prime)
}

View File

@ -112,6 +112,10 @@ func (r *Mathematical) Intn(n int) int {
// IntnErr returns a random int error combination with a maximum of n. // IntnErr returns a random int error combination with a maximum of n.
func (r *Mathematical) IntnErr(n int) (output int, err error) { func (r *Mathematical) IntnErr(n int) (output int, err error) {
if n <= 0 {
return 0, fmt.Errorf("n must be more than 0")
}
return r.Intn(n), nil return r.Intn(n), nil
} }
@ -132,15 +136,11 @@ func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) {
return nil, fmt.Errorf("max is required") return nil, fmt.Errorf("max is required")
} }
if max.Sign() <= 0 { if max.Int64() <= 0 {
return nil, fmt.Errorf("max must be 1 or more") return nil, fmt.Errorf("max must be 1 or more")
} }
r.lock.Lock() return big.NewInt(int64(r.Intn(int(max.Int64())))), nil
defer r.lock.Unlock()
return big.NewInt(int64(r.Intn(max.Sign()))), nil
} }
// Prime returns a number of the given bit length that is prime with high probability. Prime will return error for any // Prime returns a number of the given bit length that is prime with high probability. Prime will return error for any

View File

@ -0,0 +1,75 @@
package random
import (
"math/big"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMathematical(t *testing.T) {
p := NewMathematical()
data := make([]byte, 10)
n, err := p.Read(data)
assert.Equal(t, 10, n)
assert.NoError(t, err)
data2, err := p.BytesErr()
assert.NoError(t, err)
assert.Len(t, data2, 72)
data2 = p.Bytes()
assert.Len(t, data2, 72)
data2 = p.BytesCustom(74, []byte(CharSetAlphabetic))
assert.Len(t, data2, 74)
data2, err = p.BytesCustomErr(76, []byte(CharSetAlphabetic))
assert.NoError(t, err)
assert.Len(t, data2, 76)
data2, err = p.BytesCustomErr(-5, []byte(CharSetAlphabetic))
assert.NoError(t, err)
assert.Len(t, data2, 72)
strdata := p.StringCustom(10, CharSetAlphabetic)
assert.Len(t, strdata, 10)
strdata, err = p.StringCustomErr(11, CharSetAlphabetic)
assert.NoError(t, err)
assert.Len(t, strdata, 11)
i := p.Intn(999)
assert.Greater(t, i, 0)
assert.Less(t, i, 999)
i, err = p.IntnErr(999)
assert.NoError(t, err)
assert.Greater(t, i, 0)
assert.Less(t, i, 999)
i, err = p.IntnErr(-4)
assert.EqualError(t, err, "n must be more than 0")
assert.Equal(t, 0, i)
bi := p.Int(big.NewInt(999))
assert.Greater(t, bi.Int64(), int64(0))
assert.Less(t, bi.Int64(), int64(999))
bi = p.Int(nil)
assert.Equal(t, int64(-1), bi.Int64())
bi, err = p.IntErr(nil)
assert.Nil(t, bi)
assert.EqualError(t, err, "max is required")
bi, err = p.IntErr(big.NewInt(-1))
assert.Nil(t, bi)
assert.EqualError(t, err, "max must be 1 or more")
prime, err := p.Prime(64)
assert.NoError(t, err)
assert.NotNil(t, prime)
}

View File

@ -12,10 +12,10 @@ import (
) )
// NewRegulator create a regulator instance. // NewRegulator create a regulator instance.
func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator { func NewRegulator(config schema.RegulationConfiguration, store storage.RegulatorProvider, clock utils.Clock) *Regulator {
return &Regulator{ return &Regulator{
enabled: config.MaxRetries > 0, enabled: config.MaxRetries > 0,
storageProvider: provider, store: store,
clock: clock, clock: clock,
config: config, config: config,
} }
@ -26,7 +26,7 @@ func NewRegulator(config schema.RegulationConfiguration, provider storage.Regula
func (r *Regulator) Mark(ctx Context, successful, banned bool, username, requestURI, requestMethod, authType string) error { func (r *Regulator) Mark(ctx Context, successful, banned bool, username, requestURI, requestMethod, authType string) error {
ctx.RecordAuthn(successful, banned, strings.ToLower(authType)) ctx.RecordAuthn(successful, banned, strings.ToLower(authType))
return r.storageProvider.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{ return r.store.AppendAuthenticationLog(ctx, model.AuthenticationAttempt{
Time: r.clock.Now(), Time: r.clock.Now(),
Successful: successful, Successful: successful,
Banned: banned, Banned: banned,
@ -46,7 +46,7 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
return time.Time{}, nil return time.Time{}, nil
} }
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0) attempts, err := r.store.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
if err != nil { if err != nil {
return time.Time{}, nil return time.Time{}, nil
} }

View File

@ -1,46 +1,69 @@
package regulation_test package regulation_test
import ( import (
"context" "fmt"
"net"
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/regulation"
"github.com/authelia/authelia/v4/internal/utils"
) )
type RegulatorSuite struct { type RegulatorSuite struct {
suite.Suite suite.Suite
ctx context.Context mock *mocks.MockAutheliaCtx
ctrl *gomock.Controller
storageMock *mocks.MockStorage
config schema.RegulationConfiguration
clock utils.TestingClock
} }
func (s *RegulatorSuite) SetupTest() { func (s *RegulatorSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T()) s.mock = mocks.NewMockAutheliaCtx(s.T())
s.storageMock = mocks.NewMockStorage(s.ctrl) s.mock.Ctx.Configuration.Regulation = schema.RegulationConfiguration{
s.ctx = context.Background()
s.config = schema.RegulationConfiguration{
MaxRetries: 3, MaxRetries: 3,
BanTime: time.Second * 180, BanTime: time.Second * 180,
FindTime: time.Second * 30, FindTime: time.Second * 30,
} }
s.clock.Set(time.Now())
s.mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedFor, "127.0.0.1")
} }
func (s *RegulatorSuite) TearDownTest() { func (s *RegulatorSuite) TearDownTest() {
s.ctrl.Finish() s.mock.Ctrl.Finish()
}
func (s *RegulatorSuite) TestShouldMark() {
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
s.mock.StorageMock.EXPECT().AppendAuthenticationLog(s.mock.Ctx, model.AuthenticationAttempt{
Time: s.mock.Clock.Now(),
Successful: true,
Banned: false,
Username: "john",
Type: "1fa",
RemoteIP: model.NewNullIP(net.ParseIP("127.0.0.1")),
RequestURI: "https://google.com",
RequestMethod: fasthttp.MethodGet,
})
s.NoError(regulator.Mark(s.mock.Ctx, true, false, "john", "https://google.com", fasthttp.MethodGet, "1fa"))
}
func (s *RegulatorSuite) TestShouldHandleRegulateError() {
regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
s.mock.StorageMock.EXPECT().LoadAuthenticationLogs(s.mock.Ctx, "john", s.mock.Clock.Now().Add(-s.mock.Ctx.Configuration.Regulation.BanTime), 10, 0).Return(nil, fmt.Errorf("failed"))
until, err := regulator.Regulate(s.mock.Ctx, "john")
s.NoError(err)
s.Equal(time.Time{}, until)
} }
func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() { func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
@ -48,17 +71,17 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
{ {
Username: "john", Username: "john",
Successful: true, Successful: true,
Time: s.clock.Now().Add(-4 * time.Minute), Time: s.mock.Clock.Now().Add(-4 * time.Minute),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
} }
@ -69,27 +92,27 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-1 * time.Second), Time: s.mock.Clock.Now().Add(-1 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-90 * time.Second), Time: s.mock.Clock.Now().Add(-90 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-180 * time.Second), Time: s.mock.Clock.Now().Add(-180 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
} }
@ -100,32 +123,32 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-1 * time.Second), Time: s.mock.Clock.Now().Add(-1 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-4 * time.Second), Time: s.mock.Clock.Now().Add(-4 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-6 * time.Second), Time: s.mock.Clock.Now().Add(-6 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-180 * time.Second), Time: s.mock.Clock.Now().Add(-180 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err) assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
} }
@ -138,27 +161,27 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-31 * time.Second), Time: s.mock.Clock.Now().Add(-31 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-34 * time.Second), Time: s.mock.Clock.Now().Add(-34 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-36 * time.Second), Time: s.mock.Clock.Now().Add(-36 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err) assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
} }
@ -167,22 +190,22 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-34 * time.Second), Time: s.mock.Clock.Now().Add(-34 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-36 * time.Second), Time: s.mock.Clock.Now().Add(-36 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
} }
@ -191,7 +214,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-14 * time.Second), Time: s.mock.Clock.Now().Add(-14 * time.Second),
}, },
// more than 30 seconds elapsed between this auth and the preceding one. // more than 30 seconds elapsed between this auth and the preceding one.
// In that case we don't need to regulate the user even though the number // In that case we don't need to regulate the user even though the number
@ -199,22 +222,22 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-94 * time.Second), Time: s.mock.Clock.Now().Add(-94 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-96 * time.Second), Time: s.mock.Clock.Now().Add(-96 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
} }
@ -223,34 +246,34 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-90 * time.Second), Time: s.mock.Clock.Now().Add(-90 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: true, Successful: true,
Time: s.clock.Now().Add(-93 * time.Second), Time: s.mock.Clock.Now().Add(-93 * time.Second),
}, },
// The user was almost banned but he did a successful attempt. Therefore, even if the next // The user was almost banned but he did a successful attempt. Therefore, even if the next
// failure happens within FindTime, he should not be banned. // failure happens within FindTime, he should not be banned.
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-94 * time.Second), Time: s.mock.Clock.Now().Add(-94 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-96 * time.Second), Time: s.mock.Clock.Now().Add(-96 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
} }
@ -265,22 +288,22 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-31 * time.Second), Time: s.mock.Clock.Now().Add(-31 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-34 * time.Second), Time: s.mock.Clock.Now().Add(-34 * time.Second),
}, },
{ {
Username: "john", Username: "john",
Successful: false, Successful: false,
Time: s.clock.Now().Add(-36 * time.Second), Time: s.mock.Clock.Now().Add(-36 * time.Second),
}, },
} }
s.storageMock.EXPECT(). s.mock.StorageMock.EXPECT().
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil) Return(attemptsInDB, nil)
// Check Disabled Functionality. // Check Disabled Functionality.
@ -290,8 +313,8 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
BanTime: time.Second * 180, BanTime: time.Second * 180,
} }
regulator := regulation.NewRegulator(config, s.storageMock, &s.clock) regulator := regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock)
_, err := regulator.Regulate(s.ctx, "john") _, err := regulator.Regulate(s.mock.Ctx, "john")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
// Check Enabled Functionality. // Check Enabled Functionality.
@ -301,7 +324,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
BanTime: time.Second * 180, BanTime: time.Second * 180,
} }
regulator = regulation.NewRegulator(config, s.storageMock, &s.clock) regulator = regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock)
_, err = regulator.Regulate(s.ctx, "john") _, err = regulator.Regulate(s.mock.Ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err) assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
} }

View File

@ -16,7 +16,7 @@ type Regulator struct {
config schema.RegulationConfiguration config schema.RegulationConfiguration
storageProvider storage.RegulatorProvider store storage.RegulatorProvider
clock utils.Clock clock utils.Clock
} }

View File

@ -28,59 +28,6 @@ const (
tableEncryption = "encryption" tableEncryption = "encryption"
) )
// OAuth2SessionType represents the potential OAuth 2.0 session types.
type OAuth2SessionType int
// Representation of specific OAuth 2.0 session types.
const (
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
OAuth2SessionTypeAuthorizeCode
OAuth2SessionTypeOpenIDConnect
OAuth2SessionTypePAR
OAuth2SessionTypePKCEChallenge
OAuth2SessionTypeRefreshToken
)
// String returns a string representation of this OAuth2SessionType.
func (s OAuth2SessionType) String() string {
switch s {
case OAuth2SessionTypeAccessToken:
return "access token"
case OAuth2SessionTypeAuthorizeCode:
return "authorization code"
case OAuth2SessionTypeOpenIDConnect:
return "openid connect"
case OAuth2SessionTypePAR:
return "pushed authorization request context"
case OAuth2SessionTypePKCEChallenge:
return "pkce challenge"
case OAuth2SessionTypeRefreshToken:
return "refresh token"
default:
return "invalid"
}
}
// Table returns the table name for this session type.
func (s OAuth2SessionType) Table() string {
switch s {
case OAuth2SessionTypeAccessToken:
return tableOAuth2AccessTokenSession
case OAuth2SessionTypeAuthorizeCode:
return tableOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
return tableOAuth2OpenIDConnectSession
case OAuth2SessionTypePAR:
return tableOAuth2PARContext
case OAuth2SessionTypePKCEChallenge:
return tableOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
return tableOAuth2RefreshTokenSession
default:
return ""
}
}
const ( const (
encryptionNameCheck = "check" encryptionNameCheck = "check"
) )

View File

@ -93,3 +93,56 @@ func (r EncryptionValidationTableResult) ResultDescriptor() string {
return "SUCCESS" return "SUCCESS"
} }
// OAuth2SessionType represents the potential OAuth 2.0 session types.
type OAuth2SessionType int
// Representation of specific OAuth 2.0 session types.
const (
OAuth2SessionTypeAccessToken OAuth2SessionType = iota
OAuth2SessionTypeAuthorizeCode
OAuth2SessionTypeOpenIDConnect
OAuth2SessionTypePAR
OAuth2SessionTypePKCEChallenge
OAuth2SessionTypeRefreshToken
)
// String returns a string representation of this OAuth2SessionType.
func (s OAuth2SessionType) String() string {
switch s {
case OAuth2SessionTypeAccessToken:
return "access token"
case OAuth2SessionTypeAuthorizeCode:
return "authorization code"
case OAuth2SessionTypeOpenIDConnect:
return "openid connect"
case OAuth2SessionTypePAR:
return "pushed authorization request context"
case OAuth2SessionTypePKCEChallenge:
return "pkce challenge"
case OAuth2SessionTypeRefreshToken:
return "refresh token"
default:
return "invalid"
}
}
// Table returns the table name for this session type.
func (s OAuth2SessionType) Table() string {
switch s {
case OAuth2SessionTypeAccessToken:
return tableOAuth2AccessTokenSession
case OAuth2SessionTypeAuthorizeCode:
return tableOAuth2AuthorizeCodeSession
case OAuth2SessionTypeOpenIDConnect:
return tableOAuth2OpenIDConnectSession
case OAuth2SessionTypePAR:
return tableOAuth2PARContext
case OAuth2SessionTypePKCEChallenge:
return tableOAuth2PKCERequestSession
case OAuth2SessionTypeRefreshToken:
return tableOAuth2RefreshTokenSession
default:
return ""
}
}

View File

@ -0,0 +1,87 @@
package storage
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEncryptionValidationResult(t *testing.T) {
result := &EncryptionValidationResult{
InvalidCheckValue: false,
}
assert.True(t, result.Success())
assert.True(t, result.Checked())
result = &EncryptionValidationResult{
InvalidCheckValue: true,
}
assert.False(t, result.Success())
assert.True(t, result.Checked())
result = &EncryptionValidationResult{
InvalidCheckValue: false,
Tables: map[string]EncryptionValidationTableResult{
tableWebAuthnDevices: {
Invalid: 10,
Total: 20,
},
},
}
assert.Equal(t, "FAILURE", result.Tables[tableWebAuthnDevices].ResultDescriptor())
assert.False(t, result.Success())
assert.True(t, result.Checked())
result = &EncryptionValidationResult{
InvalidCheckValue: false,
Tables: map[string]EncryptionValidationTableResult{
tableWebAuthnDevices: {
Error: fmt.Errorf("failed to check table"),
},
},
}
assert.False(t, result.Success())
assert.False(t, result.Checked())
assert.Equal(t, "N/A", result.Tables[tableWebAuthnDevices].ResultDescriptor())
result = &EncryptionValidationResult{
InvalidCheckValue: false,
Tables: map[string]EncryptionValidationTableResult{
tableWebAuthnDevices: {
Total: 20,
},
},
}
assert.True(t, result.Success())
assert.True(t, result.Checked())
assert.Equal(t, "SUCCESS", result.Tables[tableWebAuthnDevices].ResultDescriptor())
}
func TestOAuth2SessionType(t *testing.T) {
assert.Equal(t, "access token", OAuth2SessionTypeAccessToken.String())
assert.Equal(t, tableOAuth2AccessTokenSession, OAuth2SessionTypeAccessToken.Table())
assert.Equal(t, "authorization code", OAuth2SessionTypeAuthorizeCode.String())
assert.Equal(t, tableOAuth2AuthorizeCodeSession, OAuth2SessionTypeAuthorizeCode.Table())
assert.Equal(t, "openid connect", OAuth2SessionTypeOpenIDConnect.String())
assert.Equal(t, tableOAuth2OpenIDConnectSession, OAuth2SessionTypeOpenIDConnect.Table())
assert.Equal(t, "pushed authorization request context", OAuth2SessionTypePAR.String())
assert.Equal(t, tableOAuth2PARContext, OAuth2SessionTypePAR.Table())
assert.Equal(t, "pkce challenge", OAuth2SessionTypePKCEChallenge.String())
assert.Equal(t, tableOAuth2PKCERequestSession, OAuth2SessionTypePKCEChallenge.Table())
assert.Equal(t, "refresh token", OAuth2SessionTypeRefreshToken.String())
assert.Equal(t, tableOAuth2RefreshTokenSession, OAuth2SessionTypeRefreshToken.Table())
assert.Equal(t, "invalid", OAuth2SessionType(-1).String())
assert.Equal(t, "", OAuth2SessionType(-1).Table())
}

View File

@ -8,7 +8,7 @@ import (
) )
func TestShouldEncryptAndDecriptUsingAES(t *testing.T) { func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
var key [32]byte = sha256.Sum256([]byte("the key")) var key = sha256.Sum256([]byte("the key"))
var secret = "the secret" var secret = "the secret"
@ -22,7 +22,7 @@ func TestShouldEncryptAndDecriptUsingAES(t *testing.T) {
} }
func TestShouldFailDecryptOnInvalidKey(t *testing.T) { func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
var key [32]byte = sha256.Sum256([]byte("the key")) var key = sha256.Sum256([]byte("the key"))
var secret = "the secret" var secret = "the secret"
@ -37,7 +37,7 @@ func TestShouldFailDecryptOnInvalidKey(t *testing.T) {
} }
func TestShouldFailDecryptOnInvalidCypherText(t *testing.T) { func TestShouldFailDecryptOnInvalidCypherText(t *testing.T) {
var key [32]byte = sha256.Sum256([]byte("the key")) var key = sha256.Sum256([]byte("the key"))
encryptedSecret := []byte("abc123") encryptedSecret := []byte("abc123")

View File

@ -1,34 +0,0 @@
package utils
import (
"errors"
"io"
)
// NewWriteCloser creates a new io.WriteCloser from an io.Writer.
func NewWriteCloser(wr io.Writer) io.WriteCloser {
return &WriteCloser{wr: wr}
}
// WriteCloser is a io.Writer with an io.Closer.
type WriteCloser struct {
wr io.Writer
closed bool
}
// Write to the io.Writer.
func (w *WriteCloser) Write(p []byte) (n int, err error) {
if w.closed {
return -1, errors.New("already closed")
}
return w.wr.Write(p)
}
// Close the io.Closer.
func (w *WriteCloser) Close() error {
w.closed = true
return nil
}

View File

@ -8,6 +8,97 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestIsStringAbsURL(t *testing.T) {
testCases := []struct {
name string
have string
err string
}{
{
"ShouldBeAbs",
"https://google.com",
"",
},
{
"ShouldNotBeAbs",
"google.com",
"could not parse 'google.com' as a URL",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
theError := IsStringAbsURL(tc.have)
if tc.err == "" {
assert.NoError(t, theError)
} else {
assert.EqualError(t, theError, tc.err)
}
})
}
}
func TestIsStringInSliceF(t *testing.T) {
testCases := []struct {
name string
needle string
haystack []string
isEqual func(needle, item string) bool
expected bool
}{
{
"ShouldBePresent",
"good",
[]string{"good"},
func(needle, item string) bool {
return needle == item
},
true,
},
{
"ShouldNotBePresent",
"bad",
[]string{"good"},
func(needle, item string) bool {
return needle == item
},
false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, IsStringInSliceF(tc.needle, tc.haystack, tc.isEqual))
})
}
}
func TestStringHTMLEscape(t *testing.T) {
testCases := []struct {
name string
have string
expected string
}{
{
"ShouldNotAlterAlphaNum",
"abc123",
"abc123",
},
{
"ShouldEscapeSpecial",
"abc123><@#@",
"abc123&gt;&lt;@#@",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, StringHTMLEscape(tc.have))
})
}
}
func TestStringSplitDelimitedEscaped(t *testing.T) { func TestStringSplitDelimitedEscaped(t *testing.T) {
testCases := []struct { testCases := []struct {
desc, have string desc, have string

View File

@ -209,11 +209,51 @@ func TestParseTimeString(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
index, actual, err := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts) index, actualA, errA := matchParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
actualB, errB := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
actualC, errC := ParseTimeString(tc.have)
if tc.err == "" {
assert.NoError(t, errA)
assert.NoError(t, errB)
assert.NoError(t, errC)
assert.Equal(t, tc.index, index)
assert.Equal(t, tc.expected.UnixNano(), actualA.UnixNano())
assert.Equal(t, tc.expected.UnixNano(), actualB.UnixNano())
assert.Equal(t, tc.expected.UnixNano(), actualC.UnixNano())
} else {
assert.EqualError(t, errA, tc.err)
assert.EqualError(t, errB, tc.err)
assert.EqualError(t, errC, tc.err)
}
})
}
}
func TestParseTimeStringWithLayouts(t *testing.T) {
testCases := []struct {
name string
have string
index int
expected time.Time
err string
}{
{"ShouldParseIntegerAsUnix", "1675899060", -1, time.Unix(1675899060, 0), ""},
{"ShouldParseIntegerAsUnixMilli", "1675899060000", -2, time.Unix(1675899060, 0), ""},
{"ShouldParseIntegerAsUnixMicro", "1675899060000000", -3, time.Unix(1675899060, 0), ""},
{"ShouldNotParseSuperLargeInteger", "9999999999999999999999999999999999999999", -999, time.Unix(0, 0), "time value was detected as an integer but the integer could not be parsed: strconv.ParseInt: parsing \"9999999999999999999999999999999999999999\": value out of range"},
{"ShouldParseSimpleTime", "Jan 2 15:04:05 2006", 0, time.Unix(1136214245, 0), ""},
{"ShouldNotParseInvalidTime", "abc", -998, time.Unix(0, 0), "failed to find a suitable time layout for time 'abc'"},
{"ShouldMatchDate", "2020-05-01", 6, time.Unix(1588291200, 0), ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual, err := ParseTimeStringWithLayouts(tc.have, StandardTimeLayouts)
if tc.err == "" { if tc.err == "" {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tc.index, index)
assert.Equal(t, tc.expected.UnixNano(), actual.UnixNano()) assert.Equal(t, tc.expected.UnixNano(), actual.UnixNano())
} else { } else {
assert.EqualError(t, err, tc.err) assert.EqualError(t, err, tc.err)

View File

@ -59,3 +59,8 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) {
assert.False(t, isURLSafe("https://secure.example.comc", "example.com")) assert.False(t, isURLSafe("https://secure.example.comc", "example.com"))
assert.False(t, isURLSafe("https://secure.example.co", "example.com")) assert.False(t, isURLSafe("https://secure.example.co", "example.com"))
} }
func TestHasDomainSuffix(t *testing.T) {
assert.False(t, HasDomainSuffix("abc", ""))
assert.False(t, HasDomainSuffix("", ""))
}