diff --git a/internal/handlers/handler_authz.go b/internal/handlers/handler_authz.go index c143bd3bc..88596552c 100644 --- a/internal/handlers/handler_authz.go +++ b/internal/handlers/handler_authz.go @@ -112,7 +112,7 @@ func (authz *Authz) getAutheliaURL(ctx *middlewares.AutheliaCtx, provider *sessi return nil, err } - if autheliaURL != nil { + if autheliaURL != nil || authz.legacy { return autheliaURL, nil } diff --git a/internal/handlers/handler_authz_impl_legacy_test.go b/internal/handlers/handler_authz_impl_legacy_test.go index a76936529..541920b6d 100644 --- a/internal/handlers/handler_authz_impl_legacy_test.go +++ b/internal/handlers/handler_authz_impl_legacy_test.go @@ -139,7 +139,7 @@ func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsOverrideAutheliaURLDeny() { } } -func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsMissingAutheliaURLDeny() { +func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsMissingAutheliaURLBypassStatus200() { for _, method := range testRequestMethods { s.T().Run(fmt.Sprintf("Method%s", method), func(t *testing.T) { for _, targetURI := range []*url.URL{ @@ -163,6 +163,38 @@ func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsMissingAutheliaURLDeny() { authz.Handler(mock.Ctx) + assert.Equal(t, fasthttp.StatusOK, mock.Ctx.Response.StatusCode()) + assert.Equal(t, "", string(mock.Ctx.Response.Header.Peek(fasthttp.HeaderLocation))) + }) + } + }) + } +} + +func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsMissingAutheliaURLOneFactorStatus401() { + for _, method := range testRequestMethods { + s.T().Run(fmt.Sprintf("Method%s", method), func(t *testing.T) { + for _, targetURI := range []*url.URL{ + s.RequireParseRequestURI("https://one-factor.example.com"), + s.RequireParseRequestURI("https://one-factor.example.com/subpath"), + s.RequireParseRequestURI("https://one-factor.example2.com"), + s.RequireParseRequestURI("https://one-factor.example2.com/subpath"), + } { + t.Run(targetURI.String(), func(t *testing.T) { + authz := s.Builder().Build() + + mock := mocks.NewMockAutheliaCtx(t) + + defer mock.Close() + + mock.Ctx.Request.Header.Set("X-Forwarded-Method", method) + mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedProto, targetURI.Scheme) + mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedHost, targetURI.Host) + mock.Ctx.Request.Header.Set("X-Forwarded-Uri", targetURI.Path) + mock.Ctx.Request.Header.Set(fasthttp.HeaderAccept, "text/html; charset=utf-8") + + authz.Handler(mock.Ctx) + assert.Equal(t, fasthttp.StatusUnauthorized, mock.Ctx.Response.StatusCode()) assert.Equal(t, "", string(mock.Ctx.Response.Header.Peek(fasthttp.HeaderLocation))) }) @@ -171,6 +203,47 @@ func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsMissingAutheliaURLDeny() { } } +func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsRDAutheliaURLOneFactorStatus302Or303() { + for _, method := range testRequestMethods { + s.T().Run(fmt.Sprintf("Method%s", method), func(t *testing.T) { + for _, targetURI := range []*url.URL{ + s.RequireParseRequestURI("https://one-factor.example.com/"), + s.RequireParseRequestURI("https://one-factor.example.com/subpath"), + } { + t.Run(targetURI.String(), func(t *testing.T) { + authz := s.Builder().Build() + + mock := mocks.NewMockAutheliaCtx(t) + + defer mock.Close() + + mock.Ctx.Request.Header.Set("X-Forwarded-Method", method) + mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedProto, targetURI.Scheme) + mock.Ctx.Request.Header.Set(fasthttp.HeaderXForwardedHost, targetURI.Host) + mock.Ctx.Request.Header.Set("X-Forwarded-Uri", targetURI.Path) + mock.Ctx.Request.Header.Set(fasthttp.HeaderAccept, "text/html; charset=utf-8") + mock.Ctx.Request.SetRequestURI("/api/verify?rd=https%3A%2F%2Fauth.example.com") + + authz.Handler(mock.Ctx) + + switch method { + case fasthttp.MethodGet, fasthttp.MethodOptions: + assert.Equal(t, fasthttp.StatusFound, mock.Ctx.Response.StatusCode()) + default: + assert.Equal(t, fasthttp.StatusSeeOther, mock.Ctx.Response.StatusCode()) + } + + query := &url.Values{} + query.Set("rd", targetURI.String()) + query.Set("rm", method) + + assert.Equal(t, fmt.Sprintf("https://auth.example.com/?%s", query.Encode()), string(mock.Ctx.Response.Header.Peek(fasthttp.HeaderLocation))) + }) + } + }) + } +} + func (s *LegacyAuthzSuite) TestShouldHandleAllMethodsXHRDeny() { for _, method := range testRequestMethods { s.T().Run(fmt.Sprintf("Method%s", method), func(t *testing.T) {