package middlewares_test import ( "fmt" "net/mail" "testing" "time" "github.com/golang-jwt/jwt/v4" "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/session" ) const testJWTSecret = "abc" func newArgs(retriever func(ctx *middlewares.AutheliaCtx) (*session.Identity, error)) middlewares.IdentityVerificationStartArgs { return middlewares.IdentityVerificationStartArgs{ ActionClaim: "Claim", MailButtonContent: "Register", MailTitle: "Title", TargetEndpoint: "/target", IdentityRetrieverFunc: retriever, } } func defaultRetriever(ctx *middlewares.AutheliaCtx) (*session.Identity, error) { return &session.Identity{ Username: "john", Email: "john@example.com", }, nil } func TestShouldSkipStartIdentityVerificationIf2FASkipEnabled(t *testing.T) { testCases := []bool{true, false} for _, testCaseSkipEnabled := range testCases { t.Run(fmt.Sprintf("SkipIfAuthLevelTwoFactor=%t", testCaseSkipEnabled), func(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") if testCaseSkipEnabled == false { mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) mock.NotifierMock.EXPECT(). Send(gomock.Eq(mail.Address{Address: "john@example.com"}), gomock.Eq("Title"), gomock.Any(), gomock.Any()). Return(nil) } userSession := mock.Ctx.GetSession() userSession.AuthenticationLevel = authentication.TwoFactor assert.NoError(t, mock.Ctx.SaveSession(userSession)) args := newArgs(defaultRetriever) args.IdentityVerificationCommonArgs.SkipIfAuthLevelTwoFactor = testCaseSkipEnabled middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) }) } } func TestShouldFailStartingProcessIfUserHasNoEmailAddress(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() retriever := func(ctx *middlewares.AutheliaCtx) (*session.Identity, error) { return nil, fmt.Errorf("User does not have any email") } middlewares.IdentityVerificationStart(newArgs(retriever), nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message) } func TestShouldFailIfJWTCannotBeSaved(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(fmt.Errorf("cannot save")) args := newArgs(defaultRetriever) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message) } func TestShouldFailSendingAnEmail(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) mock.NotifierMock.EXPECT(). Send(gomock.Eq(mail.Address{Address: "john@example.com"}), gomock.Eq("Title"), gomock.Any(), gomock.Any()). Return(fmt.Errorf("no notif")) args := newArgs(defaultRetriever) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "no notif", mock.Hook.LastEntry().Message) } func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) args := newArgs(defaultRetriever) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message) } func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) mock.NotifierMock.EXPECT(). Send(gomock.Eq(mail.Address{Address: "john@example.com"}), gomock.Eq("Title"), gomock.Any(), gomock.Any()). Return(nil) args := newArgs(defaultRetriever) middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) defer mock.Close() } // Test Finish process. type IdentityVerificationFinishProcess struct { suite.Suite mock *mocks.MockAutheliaCtx } func (s *IdentityVerificationFinishProcess) SetupTest() { s.mock = mocks.NewMockAutheliaCtx(s.T()) s.mock.Ctx.Configuration.JWTSecret = testJWTSecret } func (s *IdentityVerificationFinishProcess) TearDownTest() { s.mock.Close() } func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification model.IdentityVerification) { verification = model.NewIdentityVerification(uuid.New(), username, action, ctx.Ctx.RemoteIP()) verification.ExpiresAt = expiresAt claims := verification.ToIdentityVerificationClaim() token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret)) return ss, verification } func next(ctx *middlewares.AutheliaCtx, username string) {} func newFinishArgs() middlewares.IdentityVerificationFinishArgs { return middlewares.IdentityVerificationFinishArgs{ ActionClaim: "EXP_ACTION", IsTokenUserValidFunc: func(ctx *middlewares.AutheliaCtx, username string) bool { return true }, } } func (s *IdentityVerificationFinishProcess) TestShouldFailIfJSONBodyIsMalformed() { middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed") assert.Equal(s.T(), "unexpected end of JSON input", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotProvided() { s.mock.Ctx.Request.SetBodyString("{}") middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed") assert.Equal(s.T(), "No token provided", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() { token, verification := createToken(s.mock, "john", "Login", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(false, nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "The identity verification token has already been used") assert.Equal(s.T(), "Token is not in DB, it might have already been used", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() { s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}") middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed") assert.Equal(s.T(), "Cannot parse token", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() { args := newArgs(defaultRetriever) token, _ := createToken(s.mock, "john", args.ActionClaim, time.Now().Add(-1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "The identity verification token has expired") assert.Equal(s.T(), "Token expired", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() { token, verification := createToken(s.mock, "", "", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed") assert.Equal(s.T(), "This token has not been generated for this kind of action", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() { token, verification := createToken(s.mock, "harry", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) args := newFinishArgs() args.IsTokenUserValidFunc = func(ctx *middlewares.AutheliaCtx, username string) bool { return false } middlewares.IdentityVerificationFinish(args, next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed") assert.Equal(s.T(), "This token has not been generated for this user", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() { token, verification := createToken(s.mock, "john", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) s.mock.StorageMock.EXPECT(). ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(model.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(fmt.Errorf("cannot remove")) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed") assert.Equal(s.T(), "cannot remove", s.mock.Hook.LastEntry().Message) } func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() { token, verification := createToken(s.mock, "john", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) s.mock.StorageMock.EXPECT(). ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(model.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) } func (s *IdentityVerificationFinishProcess) TestShouldSkipIf2FASkipEnabled() { testCases := []bool{true, false} for _, testCaseSkipEnabled := range testCases { s.Run(fmt.Sprintf("SkipIfAuthLevelTwoFactor=%t", testCaseSkipEnabled), func() { token, verification := createToken(s.mock, "john", "EXP_ACTION", time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) if testCaseSkipEnabled == false { s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) s.mock.StorageMock.EXPECT(). ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(model.NewNullIP(s.mock.Ctx.RemoteIP()))). Return(nil) } userSession := s.mock.Ctx.GetSession() userSession.AuthenticationLevel = authentication.TwoFactor assert.NoError(s.T(), s.mock.Ctx.SaveSession(userSession)) args := newFinishArgs() args.IdentityVerificationCommonArgs.SkipIfAuthLevelTwoFactor = testCaseSkipEnabled middlewares.IdentityVerificationFinish(args, next)(s.mock.Ctx) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) }) } } func TestRunIdentityVerificationFinish(t *testing.T) { s := new(IdentityVerificationFinishProcess) suite.Run(t, s) }