feat(server): handle head method (#5003)
This implements some HEAD method handlers for various static resources and the /api/health endpoint.pull/5004/head
parent
f68e5cf957
commit
a345490826
|
@ -76,6 +76,14 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/handlers.configuration.PasswordPolicyConfigurationBody'
|
$ref: '#/components/schemas/handlers.configuration.PasswordPolicyConfigurationBody'
|
||||||
/api/health:
|
/api/health:
|
||||||
|
head:
|
||||||
|
tags:
|
||||||
|
- State
|
||||||
|
summary: Application Health
|
||||||
|
description: The health check endpoint provides information about the health of Authelia.
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Successful Operation
|
||||||
get:
|
get:
|
||||||
tags:
|
tags:
|
||||||
- State
|
- State
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
|
@ -11,20 +9,11 @@ import (
|
||||||
// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
||||||
// OpenID Connect Discovery 1.0 metadata.
|
// OpenID Connect Discovery 1.0 metadata.
|
||||||
//
|
//
|
||||||
// https://datatracker.ietf.org/doc/html/rfc5785
|
// RFC5785: Defining Well-Known URIs (https://datatracker.ietf.org/doc/html/rfc5785)
|
||||||
//
|
//
|
||||||
// https://openid.net/specs/openid-connect-discovery-1_0.html
|
// OpenID Connect Discovery 1.0 (https://openid.net/specs/openid-connect-discovery-1_0.html)
|
||||||
func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||||
var (
|
if err := ctx.ReplyJSON(ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(ctx.RootURL().String()), fasthttp.StatusOK); err != nil {
|
||||||
issuer *url.URL
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
issuer = ctx.RootURL()
|
|
||||||
|
|
||||||
wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer.String())
|
|
||||||
|
|
||||||
if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
|
|
||||||
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
|
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
|
||||||
|
|
||||||
// TODO: Determine if this is the appropriate error code here.
|
// TODO: Determine if this is the appropriate error code here.
|
||||||
|
@ -37,20 +26,11 @@ func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||||
// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
||||||
// OAuth 2.0 Authorization Server Metadata (RFC8414).
|
// OAuth 2.0 Authorization Server Metadata (RFC8414).
|
||||||
//
|
//
|
||||||
// https://datatracker.ietf.org/doc/html/rfc5785
|
// RFC5785: Defining Well-Known URIs (https://datatracker.ietf.org/doc/html/rfc5785)
|
||||||
//
|
//
|
||||||
// https://datatracker.ietf.org/doc/html/rfc8414
|
// RFC8414: OAuth 2.0 Authorization Server Metadata (https://datatracker.ietf.org/doc/html/rfc8414)
|
||||||
func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||||
var (
|
if err := ctx.ReplyJSON(ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(ctx.RootURL().String()), fasthttp.StatusOK); err != nil {
|
||||||
issuer *url.URL
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
issuer = ctx.RootURL()
|
|
||||||
|
|
||||||
wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer.String())
|
|
||||||
|
|
||||||
if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
|
|
||||||
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
|
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
|
||||||
|
|
||||||
// TODO: Determine if this is the appropriate error code here.
|
// TODO: Determine if this is the appropriate error code here.
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
@ -64,9 +65,17 @@ func newPublicHTMLEmbeddedHandler() fasthttp.RequestHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.SetContentType(contentType)
|
ctx.SetContentType(contentType)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ctx.IsHead():
|
||||||
|
ctx.Response.ResetBody()
|
||||||
|
ctx.Response.SkipBody = true
|
||||||
|
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(len(data)))
|
||||||
|
default:
|
||||||
ctx.SetBody(data)
|
ctx.SetBody(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newLocalesPathResolver() func(ctx *fasthttp.RequestCtx) (supported bool, asset string) {
|
func newLocalesPathResolver() func(ctx *fasthttp.RequestCtx) (supported bool, asset string) {
|
||||||
var (
|
var (
|
||||||
|
@ -182,9 +191,16 @@ func newLocalesEmbeddedHandler() (handler fasthttp.RequestHandler) {
|
||||||
|
|
||||||
middlewares.SetContentTypeApplicationJSON(ctx)
|
middlewares.SetContentTypeApplicationJSON(ctx)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ctx.IsHead():
|
||||||
|
ctx.Response.ResetBody()
|
||||||
|
ctx.Response.SkipBody = true
|
||||||
|
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(len(data)))
|
||||||
|
default:
|
||||||
ctx.SetBody(data)
|
ctx.SetBody(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getEmbedETags(embedFS embed.FS, root string, etags map[string][]byte) {
|
func getEmbedETags(embedFS embed.FS, root string, etags map[string][]byte) {
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
@ -77,10 +78,10 @@ func handleError() func(ctx *fasthttp.RequestCtx, err error) {
|
||||||
|
|
||||||
func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
path := strings.ToLower(string(ctx.Path()))
|
uri := strings.ToLower(string(ctx.Path()))
|
||||||
|
|
||||||
for i := 0; i < len(dirsHTTPServer); i++ {
|
for i := 0; i < len(dirsHTTPServer); i++ {
|
||||||
if path == dirsHTTPServer[i].name || strings.HasPrefix(path, dirsHTTPServer[i].prefix) {
|
if uri == dirsHTTPServer[i].name || strings.HasPrefix(uri, dirsHTTPServer[i].prefix) {
|
||||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
|
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -91,6 +92,13 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleMethodNotAllowed(ctx *fasthttp.RequestCtx) {
|
||||||
|
middlewares.SetContentTypeTextPlain(ctx)
|
||||||
|
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusMethodNotAllowed)
|
||||||
|
ctx.SetBodyString(fmt.Sprintf("%d %s", fasthttp.StatusMethodNotAllowed, fasthttp.StatusMessage(fasthttp.StatusMethodNotAllowed)))
|
||||||
|
}
|
||||||
|
|
||||||
//nolint:gocyclo
|
//nolint:gocyclo
|
||||||
func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||||
log := logging.Logger()
|
log := logging.Logger()
|
||||||
|
@ -115,29 +123,45 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
|
||||||
r := router.New()
|
r := router.New()
|
||||||
|
|
||||||
// Static Assets.
|
// Static Assets.
|
||||||
|
r.HEAD("/", bridge(serveIndexHandler))
|
||||||
r.GET("/", bridge(serveIndexHandler))
|
r.GET("/", bridge(serveIndexHandler))
|
||||||
|
|
||||||
for _, f := range filesRoot {
|
for _, f := range filesRoot {
|
||||||
|
r.HEAD("/"+f, handlerPublicHTML)
|
||||||
r.GET("/"+f, handlerPublicHTML)
|
r.GET("/"+f, handlerPublicHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.HEAD("/favicon.ico", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerPublicHTML))
|
||||||
r.GET("/favicon.ico", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerPublicHTML))
|
r.GET("/favicon.ico", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerPublicHTML))
|
||||||
|
|
||||||
|
r.HEAD("/static/media/logo.png", middlewares.AssetOverride(config.Server.AssetPath, 2, handlerPublicHTML))
|
||||||
r.GET("/static/media/logo.png", middlewares.AssetOverride(config.Server.AssetPath, 2, handlerPublicHTML))
|
r.GET("/static/media/logo.png", middlewares.AssetOverride(config.Server.AssetPath, 2, handlerPublicHTML))
|
||||||
|
|
||||||
|
r.HEAD("/static/{filepath:*}", handlerPublicHTML)
|
||||||
r.GET("/static/{filepath:*}", handlerPublicHTML)
|
r.GET("/static/{filepath:*}", handlerPublicHTML)
|
||||||
|
|
||||||
// Locales.
|
// Locales.
|
||||||
|
r.HEAD("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
|
||||||
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
|
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-zA-Z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
|
||||||
|
|
||||||
|
r.HEAD("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
|
||||||
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
|
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverride(config.Server.AssetPath, 0, handlerLocales))
|
||||||
|
|
||||||
// Swagger.
|
// Swagger.
|
||||||
|
r.HEAD("/api/", bridge(serveOpenAPIHandler))
|
||||||
r.GET("/api/", bridge(serveOpenAPIHandler))
|
r.GET("/api/", bridge(serveOpenAPIHandler))
|
||||||
r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)
|
r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
|
||||||
|
r.HEAD("/api/index.html", bridge(serveOpenAPIHandler))
|
||||||
r.GET("/api/index.html", bridge(serveOpenAPIHandler))
|
r.GET("/api/index.html", bridge(serveOpenAPIHandler))
|
||||||
r.OPTIONS("/api/index.html", policyCORSPublicGET.HandleOPTIONS)
|
r.OPTIONS("/api/index.html", policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
|
||||||
|
r.HEAD("/api/openapi.yml", policyCORSPublicGET.Middleware(bridge(serveOpenAPISpecHandler)))
|
||||||
r.GET("/api/openapi.yml", policyCORSPublicGET.Middleware(bridge(serveOpenAPISpecHandler)))
|
r.GET("/api/openapi.yml", policyCORSPublicGET.Middleware(bridge(serveOpenAPISpecHandler)))
|
||||||
r.OPTIONS("/api/openapi.yml", policyCORSPublicGET.HandleOPTIONS)
|
r.OPTIONS("/api/openapi.yml", policyCORSPublicGET.HandleOPTIONS)
|
||||||
|
|
||||||
for _, file := range filesSwagger {
|
for _, file := range filesSwagger {
|
||||||
|
r.HEAD("/api/"+file, handlerPublicHTML)
|
||||||
r.GET("/api/"+file, handlerPublicHTML)
|
r.GET("/api/"+file, handlerPublicHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,7 +174,9 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
|
||||||
WithPostMiddlewares(middlewares.Require1FA).
|
WithPostMiddlewares(middlewares.Require1FA).
|
||||||
Build()
|
Build()
|
||||||
|
|
||||||
|
r.HEAD("/api/health", middlewareAPI(handlers.HealthGET))
|
||||||
r.GET("/api/health", middlewareAPI(handlers.HealthGET))
|
r.GET("/api/health", middlewareAPI(handlers.HealthGET))
|
||||||
|
|
||||||
r.GET("/api/state", middlewareAPI(handlers.StateGET))
|
r.GET("/api/state", middlewareAPI(handlers.StateGET))
|
||||||
|
|
||||||
r.GET("/api/configuration", middleware1FA(handlers.ConfigurationGET))
|
r.GET("/api/configuration", middleware1FA(handlers.ConfigurationGET))
|
||||||
|
@ -356,7 +382,7 @@ func handleRouter(config *schema.Configuration, providers middlewares.Providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.HandleMethodNotAllowed = true
|
r.HandleMethodNotAllowed = true
|
||||||
r.MethodNotAllowed = handlers.Status(fasthttp.StatusMethodNotAllowed)
|
r.MethodNotAllowed = handleMethodNotAllowed
|
||||||
r.NotFound = handleNotFound(bridge(serveIndexHandler))
|
r.NotFound = handleNotFound(bridge(serveIndexHandler))
|
||||||
|
|
||||||
handler := middlewares.LogRequest(r.Handler)
|
handler := middlewares.LogRequest(r.Handler)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -25,7 +26,7 @@ import (
|
||||||
// and generate a nonce to support a restrictive CSP while using material-ui.
|
// and generate a nonce to support a restrictive CSP while using material-ui.
|
||||||
func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middlewares.RequestHandler {
|
func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middlewares.RequestHandler {
|
||||||
isDevEnvironment := os.Getenv(environment) == dev
|
isDevEnvironment := os.Getenv(environment) == dev
|
||||||
ext := filepath.Ext(t.Name())
|
ext := path.Ext(t.Name())
|
||||||
|
|
||||||
return func(ctx *middlewares.AutheliaCtx) {
|
return func(ctx *middlewares.AutheliaCtx) {
|
||||||
var err error
|
var err error
|
||||||
|
@ -67,18 +68,34 @@ func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middle
|
||||||
rememberMe = strconv.FormatBool(!provider.Config.DisableRememberMe)
|
rememberMe = strconv.FormatBool(!provider.Config.DisableRememberMe)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = t.Execute(ctx.Response.BodyWriter(), opts.CommonData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce, logoOverride, rememberMe)); err != nil {
|
data := &bytes.Buffer{}
|
||||||
ctx.RequestCtx.Error("an error occurred", 503)
|
|
||||||
|
if err = t.Execute(data, opts.CommonData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce, logoOverride, rememberMe)); err != nil {
|
||||||
|
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
|
||||||
ctx.Logger.WithError(err).Errorf("Error occcurred rendering template")
|
ctx.Logger.WithError(err).Errorf("Error occcurred rendering template")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ctx.IsHead():
|
||||||
|
ctx.Response.ResetBody()
|
||||||
|
ctx.Response.SkipBody = true
|
||||||
|
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(data.Len()))
|
||||||
|
default:
|
||||||
|
if _, err = data.WriteTo(ctx.Response.BodyWriter()); err != nil {
|
||||||
|
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
|
||||||
|
ctx.Logger.WithError(err).Errorf("Error occcurred writing body")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeTemplatedOpenAPI serves templated OpenAPI related files.
|
// ServeTemplatedOpenAPI serves templated OpenAPI related files.
|
||||||
func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) middlewares.RequestHandler {
|
func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) middlewares.RequestHandler {
|
||||||
ext := filepath.Ext(t.Name())
|
ext := path.Ext(t.Name())
|
||||||
|
|
||||||
spec := ext == extYML
|
spec := ext == extYML
|
||||||
|
|
||||||
|
@ -103,12 +120,28 @@ func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) mid
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if err = t.Execute(ctx.Response.BodyWriter(), opts.OpenAPIData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce)); err != nil {
|
data := &bytes.Buffer{}
|
||||||
ctx.RequestCtx.Error("an error occurred", 503)
|
|
||||||
|
if err = t.Execute(data, opts.OpenAPIData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce)); err != nil {
|
||||||
|
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
|
||||||
ctx.Logger.WithError(err).Errorf("Error occcurred rendering template")
|
ctx.Logger.WithError(err).Errorf("Error occcurred rendering template")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ctx.IsHead():
|
||||||
|
ctx.Response.ResetBody()
|
||||||
|
ctx.Response.SkipBody = true
|
||||||
|
ctx.Response.Header.Set(fasthttp.HeaderContentLength, strconv.Itoa(data.Len()))
|
||||||
|
default:
|
||||||
|
if _, err = data.WriteTo(ctx.Response.BodyWriter()); err != nil {
|
||||||
|
ctx.RequestCtx.Error("an error occurred", fasthttp.StatusServiceUnavailable)
|
||||||
|
ctx.Logger.WithError(err).Errorf("Error occcurred writing body")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,6 +172,11 @@ func ETagRootURL(next middlewares.RequestHandler) middlewares.RequestHandler {
|
||||||
|
|
||||||
next(ctx)
|
next(ctx)
|
||||||
|
|
||||||
|
if ctx.Response.SkipBody || ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||||
|
// Skip generating the ETag as the response body should be empty.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
|
||||||
h.Write(ctx.Response.Body())
|
h.Write(ctx.Response.Body())
|
||||||
|
|
|
@ -7,61 +7,69 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WARNING: This scenario is intended to be used with TLS enabled in the authelia backend.
|
// WARNING: This scenario is intended to be used with TLS enabled in the authelia backend.
|
||||||
|
|
||||||
type BackendProtectionScenario struct {
|
type BackendProtectionScenario struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
|
|
||||||
|
client *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBackendProtectionScenario() *BackendProtectionScenario {
|
func NewBackendProtectionScenario() *BackendProtectionScenario {
|
||||||
return &BackendProtectionScenario{}
|
return &BackendProtectionScenario{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BackendProtectionScenario) SetupSuite() {
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in suites. Not used in production.
|
||||||
|
}
|
||||||
|
|
||||||
|
s.client = &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BackendProtectionScenario) AssertRequestStatusCode(method, url string, expectedStatusCode int) {
|
func (s *BackendProtectionScenario) AssertRequestStatusCode(method, url string, expectedStatusCode int) {
|
||||||
s.Run(url, func() {
|
s.Run(url, func() {
|
||||||
req, err := http.NewRequest(method, url, nil)
|
req, err := http.NewRequest(method, url, nil)
|
||||||
s.Assert().NoError(err)
|
s.Assert().NoError(err)
|
||||||
|
|
||||||
tr := &http.Transport{
|
res, err := s.client.Do(req)
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in suites. Not used in production.
|
|
||||||
}
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: tr,
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
},
|
|
||||||
}
|
|
||||||
res, err := client.Do(req)
|
|
||||||
s.Assert().NoError(err)
|
s.Assert().NoError(err)
|
||||||
s.Assert().Equal(expectedStatusCode, res.StatusCode)
|
s.Assert().Equal(expectedStatusCode, res.StatusCode)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackendProtectionScenario) TestProtectionOfBackendEndpoints() {
|
func (s *BackendProtectionScenario) TestProtectionOfBackendEndpoints() {
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/totp", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/totp", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/webauthn/assertion", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/webauthn/assertion", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/webauthn/attestation", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/webauthn/attestation", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/user/info/2fa_method", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/user/info/2fa_method", AutheliaBaseURL), 403)
|
||||||
|
|
||||||
s.AssertRequestStatusCode("GET", fmt.Sprintf("%s/api/user/info", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodGet, fmt.Sprintf("%s/api/user/info", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("GET", fmt.Sprintf("%s/api/configuration", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodGet, fmt.Sprintf("%s/api/configuration", AutheliaBaseURL), 403)
|
||||||
|
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/totp/identity/start", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/totp/identity/start", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/totp/identity/finish", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/totp/identity/finish", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/webauthn/identity/start", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/webauthn/identity/start", AutheliaBaseURL), 403)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/secondfactor/webauthn/identity/finish", AutheliaBaseURL), 403)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/secondfactor/webauthn/identity/finish", AutheliaBaseURL), 403)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackendProtectionScenario) TestInvalidEndpointsReturn404() {
|
func (s *BackendProtectionScenario) TestInvalidEndpointsReturn404() {
|
||||||
s.AssertRequestStatusCode("GET", fmt.Sprintf("%s/api/not_existing", AutheliaBaseURL), 404)
|
s.AssertRequestStatusCode(fasthttp.MethodGet, fmt.Sprintf("%s/api/not_existing", AutheliaBaseURL), 404)
|
||||||
s.AssertRequestStatusCode("HEAD", fmt.Sprintf("%s/api/not_existing", AutheliaBaseURL), 404)
|
s.AssertRequestStatusCode(fasthttp.MethodHead, fmt.Sprintf("%s/api/not_existing", AutheliaBaseURL), 404)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/not_existing", AutheliaBaseURL), 404)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/not_existing", AutheliaBaseURL), 404)
|
||||||
|
|
||||||
s.AssertRequestStatusCode("GET", fmt.Sprintf("%s/api/not_existing/second", AutheliaBaseURL), 404)
|
s.AssertRequestStatusCode(fasthttp.MethodGet, fmt.Sprintf("%s/api/not_existing/second", AutheliaBaseURL), 404)
|
||||||
s.AssertRequestStatusCode("HEAD", fmt.Sprintf("%s/api/not_existing/second", AutheliaBaseURL), 404)
|
s.AssertRequestStatusCode(fasthttp.MethodHead, fmt.Sprintf("%s/api/not_existing/second", AutheliaBaseURL), 404)
|
||||||
s.AssertRequestStatusCode("POST", fmt.Sprintf("%s/api/not_existing/second", AutheliaBaseURL), 404)
|
s.AssertRequestStatusCode(fasthttp.MethodPost, fmt.Sprintf("%s/api/not_existing/second", AutheliaBaseURL), 404)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackendProtectionScenario) TestInvalidEndpointsReturn405() {
|
func (s *BackendProtectionScenario) TestInvalidEndpointsReturn405() {
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
package suites
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewRequestMethodScenario() *RequestMethodScenario {
|
||||||
|
return &RequestMethodScenario{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestMethodScenario struct {
|
||||||
|
suite.Suite
|
||||||
|
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RequestMethodScenario) SetupSuite() {
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in suites. Not used in production.
|
||||||
|
}
|
||||||
|
|
||||||
|
s.client = &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RequestMethodScenario) TestShouldRespondWithAppropriateMethodNotAllowedHeaders() {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
uri string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{"RootPathShouldShowAllowedMethodsOnInvalidRequest", fasthttp.MethodPost, AutheliaBaseURL, []string{fasthttp.MethodGet, fasthttp.MethodHead, fasthttp.MethodOptions}},
|
||||||
|
{"OpenAPISpecificationShouldShowAllowedMethodsOnInvalidRequest", fasthttp.MethodPost, fmt.Sprintf("%s/api/openapi.yml", AutheliaBaseURL), []string{fasthttp.MethodGet, fasthttp.MethodHead, fasthttp.MethodOptions}},
|
||||||
|
{"LocalesShouldShowAllowedMethodsOnInvalidRequest", fasthttp.MethodPost, fmt.Sprintf("%s/locales/en/portal.json", AutheliaBaseURL), []string{fasthttp.MethodGet, fasthttp.MethodHead, fasthttp.MethodOptions}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
s.Run(tc.name, func() {
|
||||||
|
req, err := http.NewRequest(tc.method, tc.uri, nil)
|
||||||
|
s.Assert().NoError(err)
|
||||||
|
|
||||||
|
res, err := s.client.Do(req)
|
||||||
|
|
||||||
|
s.Assert().NoError(err)
|
||||||
|
s.Assert().Equal(fasthttp.StatusMethodNotAllowed, res.StatusCode)
|
||||||
|
s.Assert().Equal(strings.Join(tc.expected, ", "), res.Header.Get(fasthttp.HeaderAllow))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RequestMethodScenario) TestShouldRespondWithAppropriateResponseWithMethodHEAD() {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
uri string
|
||||||
|
expectedStatus int
|
||||||
|
expectedContentLength bool
|
||||||
|
}{
|
||||||
|
{"RootPathShouldShowContentLengthAndRespondOK", AutheliaBaseURL, fasthttp.StatusOK, true},
|
||||||
|
{"OpenAPISpecShouldShowContentLengthAndRespondOK", fmt.Sprintf("%s/api/openapi.yml", AutheliaBaseURL), fasthttp.StatusOK, true},
|
||||||
|
{"LocalesShouldShowContentLengthAndRespondOK", fmt.Sprintf("%s/locales/en/portal.json", AutheliaBaseURL), fasthttp.StatusOK, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
s.Run(tc.name, func() {
|
||||||
|
req, err := http.NewRequest(fasthttp.MethodHead, tc.uri, nil)
|
||||||
|
s.Assert().NoError(err)
|
||||||
|
|
||||||
|
res, err := s.client.Do(req)
|
||||||
|
|
||||||
|
s.Assert().NoError(err)
|
||||||
|
s.Assert().Equal(tc.expectedStatus, res.StatusCode)
|
||||||
|
|
||||||
|
if tc.expectedContentLength {
|
||||||
|
s.Assert().NotEqual(0, res.ContentLength)
|
||||||
|
} else {
|
||||||
|
s.Assert().Equal(0, res.ContentLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(res.Body)
|
||||||
|
|
||||||
|
s.Assert().NoError(err)
|
||||||
|
s.Assert().Len(data, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunRequestMethod(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping suite test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
suite.Run(t, NewRequestMethodScenario())
|
||||||
|
}
|
|
@ -346,6 +346,10 @@ func (s *StandaloneSuite) TestResetPasswordScenario() {
|
||||||
suite.Run(s.T(), NewResetPasswordScenario())
|
suite.Run(s.T(), NewResetPasswordScenario())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StandaloneSuite) TestRequestMethodScenario() {
|
||||||
|
suite.Run(s.T(), NewRequestMethodScenario())
|
||||||
|
}
|
||||||
|
|
||||||
func (s *StandaloneSuite) TestAvailableMethodsScenario() {
|
func (s *StandaloneSuite) TestAvailableMethodsScenario() {
|
||||||
suite.Run(s.T(), NewAvailableMethodsScenario([]string{"TIME-BASED ONE-TIME PASSWORD", "SECURITY KEY - WEBAUTHN"}))
|
suite.Run(s.T(), NewAvailableMethodsScenario([]string{"TIME-BASED ONE-TIME PASSWORD", "SECURITY KEY - WEBAUTHN"}))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue