362 lines
12 KiB
Go
362 lines
12 KiB
Go
package middlewares_test
|
|
|
|
import (
|
|
"fmt"
|
|
"net/mail"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/authelia/authelia/v4/internal/authentication"
|
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
|
"github.com/authelia/authelia/v4/internal/mocks"
|
|
"github.com/authelia/authelia/v4/internal/model"
|
|
"github.com/authelia/authelia/v4/internal/session"
|
|
)
|
|
|
|
// testJWTSecret is the JWT secret the tests in this file assign to the mock
// configuration before exercising the identity verification handlers.
const testJWTSecret = "abc"
|
|
|
|
func newArgs(retriever func(ctx *middlewares.AutheliaCtx) (*session.Identity, error)) middlewares.IdentityVerificationStartArgs {
|
|
return middlewares.IdentityVerificationStartArgs{
|
|
ActionClaim: "Claim",
|
|
MailButtonContent: "Register",
|
|
MailTitle: "Title",
|
|
TargetEndpoint: "/target",
|
|
IdentityRetrieverFunc: retriever,
|
|
}
|
|
}
|
|
|
|
func defaultRetriever(ctx *middlewares.AutheliaCtx) (*session.Identity, error) {
|
|
return &session.Identity{
|
|
Username: "john",
|
|
Email: "john@example.com",
|
|
}, nil
|
|
}
|
|
|
|
func TestShouldSkipStartIdentityVerificationIf2FASkipEnabled(t *testing.T) {
|
|
testCases := []bool{true, false}
|
|
for _, testCaseSkipEnabled := range testCases {
|
|
t.Run(fmt.Sprintf("SkipIfAuthLevelTwoFactor=%t", testCaseSkipEnabled), func(t *testing.T) {
|
|
mock := mocks.NewMockAutheliaCtx(t)
|
|
defer mock.Close()
|
|
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
|
|
|
if testCaseSkipEnabled == false {
|
|
mock.StorageMock.EXPECT().
|
|
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
|
Return(nil)
|
|
mock.NotifierMock.EXPECT().
|
|
Send(gomock.Eq(mail.Address{Address: "john@example.com"}), gomock.Eq("Title"), gomock.Any(), gomock.Any()).
|
|
Return(nil)
|
|
}
|
|
|
|
userSession := mock.Ctx.GetSession()
|
|
userSession.AuthenticationLevel = authentication.TwoFactor
|
|
assert.NoError(t, mock.Ctx.SaveSession(userSession))
|
|
|
|
args := newArgs(defaultRetriever)
|
|
args.IdentityVerificationCommonArgs.SkipIfAuthLevelTwoFactor = testCaseSkipEnabled
|
|
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
|
|
|
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestShouldFailStartingProcessIfUserHasNoEmailAddress(t *testing.T) {
|
|
mock := mocks.NewMockAutheliaCtx(t)
|
|
defer mock.Close()
|
|
|
|
retriever := func(ctx *middlewares.AutheliaCtx) (*session.Identity, error) {
|
|
return nil, fmt.Errorf("User does not have any email")
|
|
}
|
|
|
|
middlewares.IdentityVerificationStart(newArgs(retriever), nil)(mock.Ctx)
|
|
|
|
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func TestShouldFailIfJWTCannotBeSaved(t *testing.T) {
|
|
mock := mocks.NewMockAutheliaCtx(t)
|
|
defer mock.Close()
|
|
|
|
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
|
|
|
mock.StorageMock.EXPECT().
|
|
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
|
Return(fmt.Errorf("cannot save"))
|
|
|
|
args := newArgs(defaultRetriever)
|
|
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
|
|
|
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func TestShouldFailSendingAnEmail(t *testing.T) {
|
|
mock := mocks.NewMockAutheliaCtx(t)
|
|
defer mock.Close()
|
|
|
|
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
|
|
|
mock.StorageMock.EXPECT().
|
|
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
|
Return(nil)
|
|
|
|
mock.NotifierMock.EXPECT().
|
|
Send(gomock.Eq(mail.Address{Address: "john@example.com"}), gomock.Eq("Title"), gomock.Any(), gomock.Any()).
|
|
Return(fmt.Errorf("no notif"))
|
|
|
|
args := newArgs(defaultRetriever)
|
|
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
|
|
|
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
|
|
mock := mocks.NewMockAutheliaCtx(t)
|
|
defer mock.Close()
|
|
|
|
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
|
|
|
mock.StorageMock.EXPECT().
|
|
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
|
Return(nil)
|
|
|
|
args := newArgs(defaultRetriever)
|
|
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
|
|
|
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
|
|
mock := mocks.NewMockAutheliaCtx(t)
|
|
|
|
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
|
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
|
|
|
mock.StorageMock.EXPECT().
|
|
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
|
Return(nil)
|
|
|
|
mock.NotifierMock.EXPECT().
|
|
Send(gomock.Eq(mail.Address{Address: "john@example.com"}), gomock.Eq("Title"), gomock.Any(), gomock.Any()).
|
|
Return(nil)
|
|
|
|
args := newArgs(defaultRetriever)
|
|
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
|
|
|
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
|
|
defer mock.Close()
|
|
}
|
|
|
|
// Test Finish process.
|
|
type IdentityVerificationFinishProcess struct {
|
|
suite.Suite
|
|
|
|
mock *mocks.MockAutheliaCtx
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) SetupTest() {
|
|
s.mock = mocks.NewMockAutheliaCtx(s.T())
|
|
|
|
s.mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TearDownTest() {
|
|
s.mock.Close()
|
|
}
|
|
|
|
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification model.IdentityVerification) {
|
|
verification = model.NewIdentityVerification(uuid.New(), username, action, ctx.Ctx.RemoteIP())
|
|
|
|
verification.ExpiresAt = expiresAt
|
|
|
|
claims := verification.ToIdentityVerificationClaim()
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
|
|
|
|
return ss, verification
|
|
}
|
|
|
|
// next is a no-op callback passed as the second argument to
// IdentityVerificationFinish in the tests below.
func next(ctx *middlewares.AutheliaCtx, username string) {
}
|
|
|
|
func newFinishArgs() middlewares.IdentityVerificationFinishArgs {
|
|
return middlewares.IdentityVerificationFinishArgs{
|
|
ActionClaim: "EXP_ACTION",
|
|
IsTokenUserValidFunc: func(ctx *middlewares.AutheliaCtx, username string) bool { return true },
|
|
}
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfJSONBodyIsMalformed() {
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "Operation failed")
|
|
assert.Equal(s.T(), "unexpected end of JSON input", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotProvided() {
|
|
s.mock.Ctx.Request.SetBodyString("{}")
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "Operation failed")
|
|
assert.Equal(s.T(), "No token provided", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() {
|
|
token, verification := createToken(s.mock, "john", "Login",
|
|
time.Now().Add(1*time.Minute))
|
|
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
Return(false, nil)
|
|
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "The identity verification token has already been used")
|
|
assert.Equal(s.T(), "Token is not in DB, it might have already been used", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() {
|
|
s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}")
|
|
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "Operation failed")
|
|
assert.Equal(s.T(), "Cannot parse token", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
|
args := newArgs(defaultRetriever)
|
|
token, _ := createToken(s.mock, "john", args.ActionClaim,
|
|
time.Now().Add(-1*time.Minute))
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "The identity verification token has expired")
|
|
assert.Equal(s.T(), "Token expired", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
|
token, verification := createToken(s.mock, "", "",
|
|
time.Now().Add(1*time.Minute))
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
Return(true, nil)
|
|
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "Operation failed")
|
|
assert.Equal(s.T(), "This token has not been generated for this kind of action", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
|
token, verification := createToken(s.mock, "harry", "EXP_ACTION",
|
|
time.Now().Add(1*time.Minute))
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
Return(true, nil)
|
|
|
|
args := newFinishArgs()
|
|
args.IsTokenUserValidFunc = func(ctx *middlewares.AutheliaCtx, username string) bool { return false }
|
|
middlewares.IdentityVerificationFinish(args, next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "Operation failed")
|
|
assert.Equal(s.T(), "This token has not been generated for this user", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() {
|
|
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
|
time.Now().Add(1*time.Minute))
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
Return(true, nil)
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(model.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
|
Return(fmt.Errorf("cannot remove"))
|
|
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
s.mock.Assert200KO(s.T(), "Operation failed")
|
|
assert.Equal(s.T(), "cannot remove", s.mock.Hook.LastEntry().Message)
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() {
|
|
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
|
time.Now().Add(1*time.Minute))
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
Return(true, nil)
|
|
|
|
s.mock.StorageMock.EXPECT().
|
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(model.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
|
Return(nil)
|
|
|
|
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
|
|
|
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
|
}
|
|
|
|
func (s *IdentityVerificationFinishProcess) TestShouldSkipIf2FASkipEnabled() {
|
|
testCases := []bool{true, false}
|
|
for _, testCaseSkipEnabled := range testCases {
|
|
s.Run(fmt.Sprintf("SkipIfAuthLevelTwoFactor=%t", testCaseSkipEnabled), func() {
|
|
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
|
time.Now().Add(1*time.Minute))
|
|
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
|
|
if testCaseSkipEnabled == false {
|
|
s.mock.StorageMock.EXPECT().
|
|
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
Return(true, nil)
|
|
s.mock.StorageMock.EXPECT().
|
|
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(model.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
|
Return(nil)
|
|
}
|
|
|
|
userSession := s.mock.Ctx.GetSession()
|
|
userSession.AuthenticationLevel = authentication.TwoFactor
|
|
assert.NoError(s.T(), s.mock.Ctx.SaveSession(userSession))
|
|
|
|
args := newFinishArgs()
|
|
args.IdentityVerificationCommonArgs.SkipIfAuthLevelTwoFactor = testCaseSkipEnabled
|
|
middlewares.IdentityVerificationFinish(args, next)(s.mock.Ctx)
|
|
|
|
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRunIdentityVerificationFinish(t *testing.T) {
|
|
s := new(IdentityVerificationFinishProcess)
|
|
suite.Run(t, s)
|
|
}
|