refactor(commands): services (#4914)

Misc refactoring of the services logic to simplify the
pull/4916/head
James Elliott 2023-02-11 21:45:26 +11:00 committed by GitHub
parent 1a5178a8a5
commit 2888ee7f41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 348 additions and 246 deletions

View File

@ -772,3 +772,7 @@ Layouts:
ANSIC: Mon Jan _2 15:04:05 2006 ANSIC: Mon Jan _2 15:04:05 2006
Date: 2006-01-02` Date: 2006-01-02`
) )
const (
fmtLogServerListening = "Server is listening for %s connections on '%s' path '%s'"
)

View File

@ -9,7 +9,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/sync/errgroup"
"github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/authorization"
@ -35,14 +34,8 @@ import (
func NewCmdCtx() *CmdCtx { func NewCmdCtx() *CmdCtx {
ctx := context.Background() ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
group, ctx := errgroup.WithContext(ctx)
return &CmdCtx{ return &CmdCtx{
Context: ctx, Context: ctx,
cancel: cancel,
group: group,
log: logging.Logger(), log: logging.Logger(),
providers: middlewares.Providers{ providers: middlewares.Providers{
Random: &random.Cryptographical{}, Random: &random.Cryptographical{},
@ -55,9 +48,6 @@ func NewCmdCtx() *CmdCtx {
type CmdCtx struct { type CmdCtx struct {
context.Context context.Context
cancel context.CancelFunc
group *errgroup.Group
log *logrus.Logger log *logrus.Logger
config *schema.Configuration config *schema.Configuration

View File

@ -2,21 +2,13 @@ package commands
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"os/signal"
"path/filepath"
"strings" "strings"
"syscall"
"github.com/fsnotify/fsnotify"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/server"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -95,195 +87,11 @@ func (ctx *CmdCtx) RootRunE(_ *cobra.Command, _ []string) (err error) {
ctx.cconfig = nil ctx.cconfig = nil
runServices(ctx) servicesRun(ctx)
return nil return nil
} }
//nolint:gocyclo // Complexity is required in this function.
func runServices(ctx *CmdCtx) {
defer ctx.cancel()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(quit)
var (
mainServer, metricsServer *fasthttp.Server
mainListener, metricsListener net.Listener
)
ctx.group.Go(func() (err error) {
defer func() {
if r := recover(); r != nil {
ctx.log.WithError(recoverErr(r)).Errorf("Server (main) critical error caught (recovered)")
}
}()
if mainServer, mainListener, err = server.CreateDefaultServer(*ctx.config, ctx.providers); err != nil {
ctx.log.WithError(err).Error("Create Server (main) returned error")
return err
}
if err = mainServer.Serve(mainListener); err != nil {
ctx.log.WithError(err).Error("Server (main) returned error")
return err
}
return nil
})
ctx.group.Go(func() (err error) {
if ctx.providers.Metrics == nil {
return nil
}
defer func() {
if r := recover(); r != nil {
ctx.log.WithError(recoverErr(r)).Errorf("Server (metrics) critical error caught (recovered)")
}
}()
if metricsServer, metricsListener, err = server.CreateMetricsServer(ctx.config.Telemetry.Metrics); err != nil {
ctx.log.WithError(err).Error("Create Server (metrics) returned error")
return err
}
if err = metricsServer.Serve(metricsListener); err != nil {
ctx.log.WithError(err).Error("Server (metrics) returned error")
return err
}
return nil
})
if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch {
provider := ctx.providers.UserProvider.(*authentication.FileUserProvider)
if watcher, err := runServiceFileWatcher(ctx, ctx.config.AuthenticationBackend.File.Path, provider); err != nil {
ctx.log.WithError(err).Errorf("File Watcher (user database) start returned error")
} else {
defer func(watcher *fsnotify.Watcher) {
if err := watcher.Close(); err != nil {
ctx.log.WithError(err).Errorf("File Watcher (user database) close returned error")
}
}(watcher)
}
}
select {
case s := <-quit:
switch s {
case syscall.SIGINT:
ctx.log.Debugf("Shutdown started due to SIGINT")
case syscall.SIGQUIT:
ctx.log.Debugf("Shutdown started due to SIGQUIT")
}
case <-ctx.Done():
ctx.log.Debugf("Shutdown started due to context completion")
}
ctx.cancel()
ctx.log.Infof("Shutting down")
var err error
if mainServer != nil {
if err = mainServer.Shutdown(); err != nil {
ctx.log.WithError(err).Errorf("Error occurred shutting down the server")
}
}
if metricsServer != nil {
if err = metricsServer.Shutdown(); err != nil {
ctx.log.WithError(err).Errorf("Error occurred shutting down the metrics server")
}
}
if err = ctx.providers.StorageProvider.Close(); err != nil {
ctx.log.WithError(err).Errorf("Error occurred closing the database connection")
}
if err = ctx.group.Wait(); err != nil {
ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
}
}
type ReloadFilter func(path string) (skipped bool)
type ProviderReload interface {
Reload() (reloaded bool, err error)
}
func runServiceFileWatcher(ctx *CmdCtx, path string, reload ProviderReload) (watcher *fsnotify.Watcher, err error) {
if watcher, err = fsnotify.NewWatcher(); err != nil {
return nil, err
}
failed := make(chan struct{})
var directory, filename string
if path != "" {
directory, filename = filepath.Dir(path), filepath.Base(path)
}
ctx.group.Go(func() error {
for {
select {
case <-failed:
return nil
case event, ok := <-watcher.Events:
if !ok {
return nil
}
if filename != filepath.Base(event.Name) {
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Tracef("File modification detected to irrelevant file")
break
}
switch {
case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("File modification detected")
switch reloaded, err := reload.Reload(); {
case err != nil:
ctx.log.WithField("file", event.Name).WithField("op", event.Op).WithError(err).Error("Error occurred reloading file")
case reloaded:
ctx.log.WithField("file", event.Name).Info("Reloaded file successfully")
default:
ctx.log.WithField("file", event.Name).Debug("Reload of file was triggered but it was skipped")
}
case event.Op&fsnotify.Remove == fsnotify.Remove:
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("Remove of file was detected")
}
case err, ok := <-watcher.Errors:
if !ok {
return nil
}
ctx.log.WithError(err).Errorf("Error while watching files")
}
}
})
if err := watcher.Add(directory); err != nil {
failed <- struct{}{}
return nil, err
}
ctx.log.WithField("directory", directory).WithField("file", filename).Debug("Directory is being watched for changes to the file")
return watcher, nil
}
func doStartupChecks(ctx *CmdCtx) { func doStartupChecks(ctx *CmdCtx) {
var ( var (
failures []string failures []string

View File

@ -0,0 +1,311 @@
package commands
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"syscall"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
"github.com/valyala/fasthttp"
"golang.org/x/sync/errgroup"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/server"
)
// NewServerService creates a new ServerService with the appropriate logger etc.
func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) {
return &ServerService{
server: server,
listener: listener,
paths: paths,
isTLS: isTLS,
log: log.WithFields(map[string]any{"service": "server", "server": name}),
}
}
// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc.
func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) {
if path == "" {
return nil, fmt.Errorf("path must be specified")
}
var info os.FileInfo
if info, err = os.Stat(path); err != nil {
return nil, fmt.Errorf("error stating file '%s': %w", path, err)
}
if path, err = filepath.Abs(path); err != nil {
return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err)
}
var watcher *fsnotify.Watcher
if watcher, err = fsnotify.NewWatcher(); err != nil {
return nil, err
}
entry := log.WithFields(map[string]any{"service": "watcher", "watcher": name})
if info.IsDir() {
service = &FileWatcherService{
watcher: watcher,
reload: reload,
log: entry,
directory: filepath.Clean(path),
}
} else {
service = &FileWatcherService{
watcher: watcher,
reload: reload,
log: entry,
directory: filepath.Dir(path),
file: filepath.Base(path),
}
}
if err = service.watcher.Add(service.directory); err != nil {
return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err)
}
return service, nil
}
// ProviderReload represents the required methods to support reloading a provider.
type ProviderReload interface {
Reload() (reloaded bool, err error)
}
// Service represents the required methods to support handling a service.
type Service interface {
Run() (err error)
Shutdown()
}
// ServerService is a Service which runs a webserver.
type ServerService struct {
server *fasthttp.Server
paths []string
isTLS bool
listener net.Listener
log *logrus.Entry
}
// Run the ServerService.
func (service *ServerService) Run() (err error) {
defer func() {
if r := recover(); r != nil {
service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
}
}()
service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '"))
if err = service.server.Serve(service.listener); err != nil {
service.log.WithError(err).Error("Error returned attempting to serve requests")
return err
}
return nil
}
// Shutdown the ServerService.
func (service *ServerService) Shutdown() {
if err := service.server.Shutdown(); err != nil {
service.log.WithError(err).Error("Error occurred during shutdown")
}
}
// FileWatcherService is a Service that watches files for changes.
type FileWatcherService struct {
watcher *fsnotify.Watcher
reload ProviderReload
log *logrus.Entry
file string
directory string
}
// Run the FileWatcherService.
func (service *FileWatcherService) Run() (err error) {
defer func() {
if r := recover(); r != nil {
service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
}
}()
service.log.WithField("file", filepath.Join(service.directory, service.file)).Info("Watching for file changes to the file")
for {
select {
case event, ok := <-service.watcher.Events:
if !ok {
return nil
}
if service.file != "" && service.file != filepath.Base(event.Name) {
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Tracef("File modification detected to irrelevant file")
break
}
switch {
case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File modification was detected")
var reloaded bool
switch reloaded, err = service.reload.Reload(); {
case err != nil:
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).WithError(err).Error("Error occurred during reload")
case reloaded:
service.log.WithField("file", event.Name).Info("Reloaded successfully")
default:
service.log.WithField("file", event.Name).Debug("Reload of was triggered but it was skipped")
}
case event.Op&fsnotify.Remove == fsnotify.Remove:
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File remove was detected")
}
case err, ok := <-service.watcher.Errors:
if !ok {
return nil
}
service.log.WithError(err).Errorf("Error while watching files")
}
}
}
// Shutdown the FileWatcherService.
func (service *FileWatcherService) Shutdown() {
if err := service.watcher.Close(); err != nil {
service.log.WithError(err).Error("Error occurred during shutdown")
}
}
func svcSvrMainFunc(ctx *CmdCtx) (service Service) {
switch svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.config, ctx.providers); {
case err != nil:
ctx.log.WithError(err).Fatal("Create Server Service (main) returned error")
case svr != nil && listener != nil:
service = NewServerService("main", svr, listener, paths, isTLS, ctx.log)
default:
ctx.log.Fatal("Create Server Service (main) failed")
}
return service
}
func svcSvrMetricsFunc(ctx *CmdCtx) (service Service) {
switch svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.config, ctx.providers); {
case err != nil:
ctx.log.WithError(err).Fatal("Create Server Service (metrics) returned error")
case svr != nil && listener != nil:
service = NewServerService("metrics", svr, listener, paths, isTLS, ctx.log)
default:
ctx.log.Debug("Create Server Service (metrics) skipped")
}
return service
}
func svcWatcherUsersFunc(ctx *CmdCtx) (service Service) {
var err error
if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch {
provider := ctx.providers.UserProvider.(*authentication.FileUserProvider)
if service, err = NewFileWatcherService("users", ctx.config.AuthenticationBackend.File.Path, provider, ctx.log); err != nil {
ctx.log.WithError(err).Fatal("Create Watcher Service (users) returned error")
}
}
return service
}
func connectionType(isTLS bool) string {
if isTLS {
return "TLS"
}
return "non-TLS"
}
func servicesRun(ctx *CmdCtx) {
cctx, cancel := context.WithCancel(ctx)
group, cctx := errgroup.WithContext(cctx)
defer cancel()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(quit)
var (
services []Service
)
for _, serviceFunc := range []func(ctx *CmdCtx) Service{
svcSvrMainFunc, svcSvrMetricsFunc,
svcWatcherUsersFunc,
} {
if service := serviceFunc(ctx); service != nil {
services = append(services, service)
group.Go(service.Run)
}
}
ctx.log.Info("Startup Complete")
select {
case s := <-quit:
switch s {
case syscall.SIGINT:
ctx.log.WithField("signal", "SIGINT").Debugf("Shutdown started due to signal")
case syscall.SIGTERM:
ctx.log.WithField("signal", "SIGTERM").Debugf("Shutdown started due to signal")
}
case <-cctx.Done():
ctx.log.Debugf("Shutdown started due to context completion")
}
cancel()
ctx.log.Infof("Shutting down")
wgShutdown := &sync.WaitGroup{}
for _, service := range services {
go func() {
service.Shutdown()
wgShutdown.Done()
}()
wgShutdown.Add(1)
}
wgShutdown.Wait()
var err error
if err = ctx.providers.StorageProvider.Close(); err != nil {
ctx.log.WithError(err).Error("Error occurred closing database connections")
}
if err = group.Wait(); err != nil {
ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
}
}

View File

@ -83,12 +83,3 @@ const (
tmplCSPSwaggerNonce = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline' 'nonce-%s'; style-src 'self' 'nonce-%s'; base-uri 'self'" tmplCSPSwaggerNonce = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline' 'nonce-%s'; style-src 'self' 'nonce-%s'; base-uri 'self'"
tmplCSPSwagger = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline'; style-src 'self'; base-uri 'self'" tmplCSPSwagger = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline'; style-src 'self'; base-uri 'self'"
) )
const (
connNonTLS = "non-TLS"
connTLS = "TLS"
)
const (
fmtLogServerInit = "Initializing %s for %s connections on '%s' path '%s'"
)

View File

@ -92,10 +92,10 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
} }
//nolint:gocyclo //nolint:gocyclo
func handleRouter(config schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
log := logging.Logger() log := logging.Logger()
optsTemplatedFile := NewTemplatedFileOptions(&config) optsTemplatedFile := NewTemplatedFileOptions(config)
serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile) serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile)
serveOpenAPIHandler := ServeTemplatedOpenAPI(providers.Templates.GetAssetOpenAPIIndexTemplate(), optsTemplatedFile) serveOpenAPIHandler := ServeTemplatedOpenAPI(providers.Templates.GetAssetOpenAPIIndexTemplate(), optsTemplatedFile)
@ -104,7 +104,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
handlerPublicHTML := newPublicHTMLEmbeddedHandler() handlerPublicHTML := newPublicHTMLEmbeddedHandler()
handlerLocales := newLocalesEmbeddedHandler() handlerLocales := newLocalesEmbeddedHandler()
bridge := middlewares.NewBridgeBuilder(config, providers). bridge := middlewares.NewBridgeBuilder(*config, providers).
WithPreMiddlewares(middlewares.SecurityHeaders).Build() WithPreMiddlewares(middlewares.SecurityHeaders).Build()
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder(). policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
@ -141,11 +141,11 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
r.GET("/api/"+file, handlerPublicHTML) r.GET("/api/"+file, handlerPublicHTML)
} }
middlewareAPI := middlewares.NewBridgeBuilder(config, providers). middlewareAPI := middlewares.NewBridgeBuilder(*config, providers).
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
Build() Build()
middleware1FA := middlewares.NewBridgeBuilder(config, providers). middleware1FA := middlewares.NewBridgeBuilder(*config, providers).
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
WithPostMiddlewares(middlewares.Require1FA). WithPostMiddlewares(middlewares.Require1FA).
Build() Build()
@ -162,7 +162,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
for name, endpoint := range config.Server.Endpoints.Authz { for name, endpoint := range config.Server.Endpoints.Authz {
uri := path.Join(pathAuthz, name) uri := path.Join(pathAuthz, name)
authz := handlers.NewAuthzBuilder().WithConfig(&config).WithEndpointConfig(endpoint).Build() authz := handlers.NewAuthzBuilder().WithConfig(config).WithEndpointConfig(endpoint).Build()
handler := middlewares.Wrap(metricsVRMW, bridge(authz.Handler)) handler := middlewares.Wrap(metricsVRMW, bridge(authz.Handler))
@ -268,7 +268,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
} }
if providers.OpenIDConnect != nil { if providers.OpenIDConnect != nil {
bridgeOIDC := middlewares.NewBridgeBuilder(config, providers).WithPreMiddlewares( bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares(
middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore,
).Build() ).Build()

View File

@ -7,7 +7,6 @@ import (
"net" "net"
"os" "os"
"strconv" "strconv"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
@ -18,9 +17,9 @@ import (
) )
// CreateDefaultServer Create Authelia's internal webserver with the given configuration and providers. // CreateDefaultServer Create Authelia's internal webserver with the given configuration and providers.
func CreateDefaultServer(config schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, err error) { func CreateDefaultServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) {
if err = providers.Templates.LoadTemplatedAssets(assets); err != nil { if err = providers.Templates.LoadTemplatedAssets(assets); err != nil {
return nil, nil, fmt.Errorf("failed to load templated assets: %w", err) return nil, nil, nil, false, fmt.Errorf("failed to load templated assets: %w", err)
} }
server = &fasthttp.Server{ server = &fasthttp.Server{
@ -38,15 +37,14 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
address := net.JoinHostPort(config.Server.Host, strconv.Itoa(config.Server.Port)) address := net.JoinHostPort(config.Server.Host, strconv.Itoa(config.Server.Port))
var ( var (
connectionType string
connectionScheme string connectionScheme string
) )
if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" { if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" {
connectionType, connectionScheme = connTLS, schemeHTTPS isTLS, connectionScheme = true, schemeHTTPS
if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil { if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil {
return nil, nil, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err) return nil, nil, nil, false, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
} }
if len(config.Server.TLS.ClientCertificates) > 0 { if len(config.Server.TLS.ClientCertificates) > 0 {
@ -56,7 +54,7 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
for _, path := range config.Server.TLS.ClientCertificates { for _, path := range config.Server.TLS.ClientCertificates {
if cert, err = os.ReadFile(path); err != nil { if cert, err = os.ReadFile(path); err != nil {
return nil, nil, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err) return nil, nil, nil, false, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err)
} }
caCertPool.AppendCertsFromPEM(cert) caCertPool.AppendCertsFromPEM(cert)
@ -69,51 +67,51 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
} }
if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil { if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err) return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err)
} }
} else { } else {
connectionType, connectionScheme = connNonTLS, schemeHTTP connectionScheme = schemeHTTP
if listener, err = net.Listen("tcp", address); err != nil { if listener, err = net.Listen("tcp", address); err != nil {
return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err) return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err)
} }
} }
if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Host, if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Host,
config.Server.Path, config.Server.Port); err != nil { config.Server.Path, config.Server.Port); err != nil {
return nil, nil, fmt.Errorf("unable to configure healthcheck: %w", err) return nil, nil, nil, false, fmt.Errorf("unable to configure healthcheck: %w", err)
} }
paths := []string{"/"} paths = []string{"/"}
if config.Server.Path != "" { if config.Server.Path != "" {
paths = append(paths, config.Server.Path) paths = append(paths, config.Server.Path)
} }
logging.Logger().Infof(fmtLogServerInit, "server", connectionType, listener.Addr().String(), strings.Join(paths, "' and '")) return server, listener, paths, isTLS, nil
return server, listener, nil
} }
// CreateMetricsServer creates a metrics server. // CreateMetricsServer creates a metrics server.
func CreateMetricsServer(config schema.TelemetryMetricsConfig) (server *fasthttp.Server, listener net.Listener, err error) { func CreateMetricsServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) {
if listener, err = config.Address.Listener(); err != nil { if providers.Metrics == nil {
return nil, nil, err return
} }
server = &fasthttp.Server{ server = &fasthttp.Server{
ErrorHandler: handleError(), ErrorHandler: handleError(),
NoDefaultServerHeader: true, NoDefaultServerHeader: true,
Handler: handleMetrics(), Handler: handleMetrics(),
ReadBufferSize: config.Buffers.Read, ReadBufferSize: config.Telemetry.Metrics.Buffers.Read,
WriteBufferSize: config.Buffers.Write, WriteBufferSize: config.Telemetry.Metrics.Buffers.Write,
ReadTimeout: config.Timeouts.Read, ReadTimeout: config.Telemetry.Metrics.Timeouts.Read,
WriteTimeout: config.Timeouts.Write, WriteTimeout: config.Telemetry.Metrics.Timeouts.Write,
IdleTimeout: config.Timeouts.Idle, IdleTimeout: config.Telemetry.Metrics.Timeouts.Idle,
Logger: logging.LoggerPrintf(logrus.DebugLevel), Logger: logging.LoggerPrintf(logrus.DebugLevel),
} }
logging.Logger().Infof(fmtLogServerInit, "server (metrics)", connNonTLS, listener.Addr().String(), "/metrics") if listener, err = config.Telemetry.Metrics.Address.Listener(); err != nil {
return nil, nil, nil, false, err
}
return server, listener, nil return server, listener, []string{"/metrics"}, false, nil
} }

View File

@ -152,7 +152,7 @@ func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLS
return nil, err return nil, err
} }
s, listener, err := CreateDefaultServer(configuration, providers) s, listener, _, _, err := CreateDefaultServer(&configuration, providers)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -42,7 +42,7 @@ func waitUntilAutheliaBackendIsReady(dockerEnvironment *DockerEnvironment) error
90*time.Second, 90*time.Second,
dockerEnvironment, dockerEnvironment,
"authelia-backend", "authelia-backend",
[]string{"Initializing server for"}) []string{"Startup Complete"})
} }
func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error { func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error {