From 2888ee7f41ed69d94807cef681ec0158f2e66241 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 11 Feb 2023 21:45:26 +1100 Subject: [PATCH] refactor(commands): services (#4914) Misc refactoring of the services logic to simplify the --- internal/commands/const.go | 4 + internal/commands/context.go | 10 -- internal/commands/root.go | 194 +------------------- internal/commands/services.go | 311 +++++++++++++++++++++++++++++++++ internal/server/const.go | 9 - internal/server/handlers.go | 14 +- internal/server/server.go | 48 +++-- internal/server/server_test.go | 2 +- internal/suites/environment.go | 2 +- 9 files changed, 348 insertions(+), 246 deletions(-) create mode 100644 internal/commands/services.go diff --git a/internal/commands/const.go b/internal/commands/const.go index ea56d2f5b..2e322a2c4 100644 --- a/internal/commands/const.go +++ b/internal/commands/const.go @@ -772,3 +772,7 @@ Layouts: ANSIC: Mon Jan _2 15:04:05 2006 Date: 2006-01-02` ) + +const ( + fmtLogServerListening = "Server is listening for %s connections on '%s' path '%s'" +) diff --git a/internal/commands/context.go b/internal/commands/context.go index 9a9e367d1..64282acb6 100644 --- a/internal/commands/context.go +++ b/internal/commands/context.go @@ -9,7 +9,6 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/pflag" - "golang.org/x/sync/errgroup" "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authorization" @@ -35,14 +34,8 @@ import ( func NewCmdCtx() *CmdCtx { ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - - group, ctx := errgroup.WithContext(ctx) - return &CmdCtx{ Context: ctx, - cancel: cancel, - group: group, log: logging.Logger(), providers: middlewares.Providers{ Random: &random.Cryptographical{}, @@ -55,9 +48,6 @@ func NewCmdCtx() *CmdCtx { type CmdCtx struct { context.Context - cancel context.CancelFunc - group *errgroup.Group - log *logrus.Logger config *schema.Configuration diff --git a/internal/commands/root.go b/internal/commands/root.go index 7b728eb5a..7d57095b5 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -2,21 +2,13 @@ package commands import ( "fmt" - "net" "os" - "os/signal" - "path/filepath" "strings" - "syscall" - "github.com/fsnotify/fsnotify" "github.com/spf13/cobra" - "github.com/valyala/fasthttp" - "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/model" - "github.com/authelia/authelia/v4/internal/server" "github.com/authelia/authelia/v4/internal/utils" ) @@ -95,195 +87,11 @@ func (ctx *CmdCtx) RootRunE(_ *cobra.Command, _ []string) (err error) { ctx.cconfig = nil - runServices(ctx) + servicesRun(ctx) return nil } -//nolint:gocyclo // Complexity is required in this function. -func runServices(ctx *CmdCtx) { - defer ctx.cancel() - - quit := make(chan os.Signal, 1) - - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - - defer signal.Stop(quit) - - var ( - mainServer, metricsServer *fasthttp.Server - mainListener, metricsListener net.Listener - ) - - ctx.group.Go(func() (err error) { - defer func() { - if r := recover(); r != nil { - ctx.log.WithError(recoverErr(r)).Errorf("Server (main) critical error caught (recovered)") - } - }() - - if mainServer, mainListener, err = server.CreateDefaultServer(*ctx.config, ctx.providers); err != nil { - ctx.log.WithError(err).Error("Create Server (main) returned error") - - return err - } - - if err = mainServer.Serve(mainListener); err != nil { - ctx.log.WithError(err).Error("Server (main) returned error") - - return err - } - - return nil - }) - - ctx.group.Go(func() (err error) { - if ctx.providers.Metrics == nil { - return nil - } - - defer func() { - if r := recover(); r != nil { - ctx.log.WithError(recoverErr(r)).Errorf("Server (metrics) critical error caught (recovered)") - } - }() - - if metricsServer, metricsListener, err = server.CreateMetricsServer(ctx.config.Telemetry.Metrics); err != nil { - ctx.log.WithError(err).Error("Create Server (metrics) returned error") - - return err - } - - if err = metricsServer.Serve(metricsListener); err != nil { - ctx.log.WithError(err).Error("Server (metrics) returned error") - - return err - } - - return nil - }) - - if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch { - provider := ctx.providers.UserProvider.(*authentication.FileUserProvider) - if watcher, err := runServiceFileWatcher(ctx, ctx.config.AuthenticationBackend.File.Path, provider); err != nil { - ctx.log.WithError(err).Errorf("File Watcher (user database) start returned error") - } else { - defer func(watcher *fsnotify.Watcher) { - if err := watcher.Close(); err != nil { - ctx.log.WithError(err).Errorf("File Watcher (user database) close returned error") - } - }(watcher) - } - } - - select { - case s := <-quit: - switch s { - case syscall.SIGINT: - ctx.log.Debugf("Shutdown started due to SIGINT") - case syscall.SIGQUIT: - ctx.log.Debugf("Shutdown started due to SIGQUIT") - } - case <-ctx.Done(): - ctx.log.Debugf("Shutdown started due to context completion") - } - - ctx.cancel() - - ctx.log.Infof("Shutting down") - - var err error - - if mainServer != nil { - if err = mainServer.Shutdown(); err != nil { - ctx.log.WithError(err).Errorf("Error occurred shutting down the server") - } - } - - if metricsServer != nil { - if err = metricsServer.Shutdown(); err != nil { - ctx.log.WithError(err).Errorf("Error occurred shutting down the metrics server") - } - } - - if err = ctx.providers.StorageProvider.Close(); err != nil { - ctx.log.WithError(err).Errorf("Error occurred closing the database connection") - } - - if err = ctx.group.Wait(); err != nil { - ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown") - } -} - -type ReloadFilter func(path string) (skipped bool) - -type ProviderReload interface { - Reload() (reloaded bool, err error) -} - -func runServiceFileWatcher(ctx *CmdCtx, path string, reload ProviderReload) (watcher *fsnotify.Watcher, err error) { - if watcher, err = fsnotify.NewWatcher(); err != nil { - return nil, err - } - - failed := make(chan struct{}) - - var directory, filename string - - if path != "" { - directory, filename = filepath.Dir(path), filepath.Base(path) - } - - ctx.group.Go(func() error { - for { - select { - case <-failed: - return nil - case event, ok := <-watcher.Events: - if !ok { - return nil - } - - if filename != filepath.Base(event.Name) { - ctx.log.WithField("file", event.Name).WithField("op", event.Op).Tracef("File modification detected to irrelevant file") - break - } - - switch { - case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create: - ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("File modification detected") - - switch reloaded, err := reload.Reload(); { - case err != nil: - ctx.log.WithField("file", event.Name).WithField("op", event.Op).WithError(err).Error("Error occurred reloading file") - case reloaded: - ctx.log.WithField("file", event.Name).Info("Reloaded file successfully") - default: - ctx.log.WithField("file", event.Name).Debug("Reload of file was triggered but it was skipped") - } - case event.Op&fsnotify.Remove == fsnotify.Remove: - ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("Remove of file was detected") - } - case err, ok := <-watcher.Errors: - if !ok { - return nil - } - ctx.log.WithError(err).Errorf("Error while watching files") - } - } - }) - - if err := watcher.Add(directory); err != nil { - failed <- struct{}{} - - return nil, err - } - - ctx.log.WithField("directory", directory).WithField("file", filename).Debug("Directory is being watched for changes to the file") - - return watcher, nil -} - func doStartupChecks(ctx *CmdCtx) { var ( failures []string diff --git a/internal/commands/services.go b/internal/commands/services.go new file mode 100644 index 000000000..3d5223ed8 --- /dev/null +++ b/internal/commands/services.go @@ -0,0 +1,311 @@ +package commands + +import ( + "context" + "fmt" + "net" + "os" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" + "github.com/valyala/fasthttp" + "golang.org/x/sync/errgroup" + + "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/server" +) + +// NewServerService creates a new ServerService with the appropriate logger etc. +func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) { + return &ServerService{ + server: server, + listener: listener, + paths: paths, + isTLS: isTLS, + log: log.WithFields(map[string]any{"service": "server", "server": name}), + } +} + +// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc. +func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) { + if path == "" { + return nil, fmt.Errorf("path must be specified") + } + + var info os.FileInfo + + if info, err = os.Stat(path); err != nil { + return nil, fmt.Errorf("error stating file '%s': %w", path, err) + } + + if path, err = filepath.Abs(path); err != nil { + return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err) + } + + var watcher *fsnotify.Watcher + + if watcher, err = fsnotify.NewWatcher(); err != nil { + return nil, err + } + + entry := log.WithFields(map[string]any{"service": "watcher", "watcher": name}) + + if info.IsDir() { + service = &FileWatcherService{ + watcher: watcher, + reload: reload, + log: entry, + directory: filepath.Clean(path), + } + } else { + service = &FileWatcherService{ + watcher: watcher, + reload: reload, + log: entry, + directory: filepath.Dir(path), + file: filepath.Base(path), + } + } + + if err = service.watcher.Add(service.directory); err != nil { + return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err) + } + + return service, nil +} + +// ProviderReload represents the required methods to support reloading a provider. +type ProviderReload interface { + Reload() (reloaded bool, err error) +} + +// Service represents the required methods to support handling a service. +type Service interface { + Run() (err error) + Shutdown() +} + +// ServerService is a Service which runs a webserver. +type ServerService struct { + server *fasthttp.Server + paths []string + isTLS bool + listener net.Listener + log *logrus.Entry +} + +// Run the ServerService. +func (service *ServerService) Run() (err error) { + defer func() { + if r := recover(); r != nil { + service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)") + } + }() + + service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '")) + + if err = service.server.Serve(service.listener); err != nil { + service.log.WithError(err).Error("Error returned attempting to serve requests") + + return err + } + + return nil +} + +// Shutdown the ServerService. +func (service *ServerService) Shutdown() { + if err := service.server.Shutdown(); err != nil { + service.log.WithError(err).Error("Error occurred during shutdown") + } +} + +// FileWatcherService is a Service that watches files for changes. +type FileWatcherService struct { + watcher *fsnotify.Watcher + reload ProviderReload + + log *logrus.Entry + file string + directory string +} + +// Run the FileWatcherService. +func (service *FileWatcherService) Run() (err error) { + defer func() { + if r := recover(); r != nil { + service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)") + } + }() + + service.log.WithField("file", filepath.Join(service.directory, service.file)).Info("Watching for file changes to the file") + + for { + select { + case event, ok := <-service.watcher.Events: + if !ok { + return nil + } + + if service.file != "" && service.file != filepath.Base(event.Name) { + service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Tracef("File modification detected to irrelevant file") + break + } + + switch { + case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create: + service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File modification was detected") + + var reloaded bool + + switch reloaded, err = service.reload.Reload(); { + case err != nil: + service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).WithError(err).Error("Error occurred during reload") + case reloaded: + service.log.WithField("file", event.Name).Info("Reloaded successfully") + default: + service.log.WithField("file", event.Name).Debug("Reload of was triggered but it was skipped") + } + case event.Op&fsnotify.Remove == fsnotify.Remove: + service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File remove was detected") + } + case err, ok := <-service.watcher.Errors: + if !ok { + return nil + } + + service.log.WithError(err).Errorf("Error while watching files") + } + } +} + +// Shutdown the FileWatcherService. +func (service *FileWatcherService) Shutdown() { + if err := service.watcher.Close(); err != nil { + service.log.WithError(err).Error("Error occurred during shutdown") + } +} + +func svcSvrMainFunc(ctx *CmdCtx) (service Service) { + switch svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.config, ctx.providers); { + case err != nil: + ctx.log.WithError(err).Fatal("Create Server Service (main) returned error") + case svr != nil && listener != nil: + service = NewServerService("main", svr, listener, paths, isTLS, ctx.log) + default: + ctx.log.Fatal("Create Server Service (main) failed") + } + + return service +} + +func svcSvrMetricsFunc(ctx *CmdCtx) (service Service) { + switch svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.config, ctx.providers); { + case err != nil: + ctx.log.WithError(err).Fatal("Create Server Service (metrics) returned error") + case svr != nil && listener != nil: + service = NewServerService("metrics", svr, listener, paths, isTLS, ctx.log) + default: + ctx.log.Debug("Create Server Service (metrics) skipped") + } + + return service +} + +func svcWatcherUsersFunc(ctx *CmdCtx) (service Service) { + var err error + + if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch { + provider := ctx.providers.UserProvider.(*authentication.FileUserProvider) + + if service, err = NewFileWatcherService("users", ctx.config.AuthenticationBackend.File.Path, provider, ctx.log); err != nil { + ctx.log.WithError(err).Fatal("Create Watcher Service (users) returned error") + } + } + + return service +} + +func connectionType(isTLS bool) string { + if isTLS { + return "TLS" + } + + return "non-TLS" +} + +func servicesRun(ctx *CmdCtx) { + cctx, cancel := context.WithCancel(ctx) + + group, cctx := errgroup.WithContext(cctx) + + defer cancel() + + quit := make(chan os.Signal, 1) + + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + defer signal.Stop(quit) + + var ( + services []Service + ) + + for _, serviceFunc := range []func(ctx *CmdCtx) Service{ + svcSvrMainFunc, svcSvrMetricsFunc, + svcWatcherUsersFunc, + } { + if service := serviceFunc(ctx); service != nil { + services = append(services, service) + + group.Go(service.Run) + } + } + + ctx.log.Info("Startup Complete") + + select { + case s := <-quit: + switch s { + case syscall.SIGINT: + ctx.log.WithField("signal", "SIGINT").Debugf("Shutdown started due to signal") + case syscall.SIGTERM: + ctx.log.WithField("signal", "SIGTERM").Debugf("Shutdown started due to signal") + } + case <-cctx.Done(): + ctx.log.Debugf("Shutdown started due to context completion") + } + + cancel() + + ctx.log.Infof("Shutting down") + + wgShutdown := &sync.WaitGroup{} + + for _, service := range services { + go func() { + service.Shutdown() + + wgShutdown.Done() + }() + + wgShutdown.Add(1) + } + + wgShutdown.Wait() + + var err error + + if err = ctx.providers.StorageProvider.Close(); err != nil { + ctx.log.WithError(err).Error("Error occurred closing database connections") + } + + if err = group.Wait(); err != nil { + ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown") + } +} diff --git a/internal/server/const.go b/internal/server/const.go index 069797b65..abbc52206 100644 --- a/internal/server/const.go +++ b/internal/server/const.go @@ -83,12 +83,3 @@ const ( tmplCSPSwaggerNonce = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline' 'nonce-%s'; style-src 'self' 'nonce-%s'; base-uri 'self'" tmplCSPSwagger = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline'; style-src 'self'; base-uri 'self'" ) - -const ( - connNonTLS = "non-TLS" - connTLS = "TLS" -) - -const ( - fmtLogServerInit = "Initializing %s for %s connections on '%s' path '%s'" -) diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 106fea7fd..8bd7717a8 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -92,10 +92,10 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler { } //nolint:gocyclo -func handleRouter(config schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { +func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { log := logging.Logger() - optsTemplatedFile := NewTemplatedFileOptions(&config) + optsTemplatedFile := NewTemplatedFileOptions(config) serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile) serveOpenAPIHandler := ServeTemplatedOpenAPI(providers.Templates.GetAssetOpenAPIIndexTemplate(), optsTemplatedFile) @@ -104,7 +104,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers) handlerPublicHTML := newPublicHTMLEmbeddedHandler() handlerLocales := newLocalesEmbeddedHandler() - bridge := middlewares.NewBridgeBuilder(config, providers). + bridge := middlewares.NewBridgeBuilder(*config, providers). WithPreMiddlewares(middlewares.SecurityHeaders).Build() policyCORSPublicGET := middlewares.NewCORSPolicyBuilder(). @@ -141,11 +141,11 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers) r.GET("/api/"+file, handlerPublicHTML) } - middlewareAPI := middlewares.NewBridgeBuilder(config, providers). + middlewareAPI := middlewares.NewBridgeBuilder(*config, providers). WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). Build() - middleware1FA := middlewares.NewBridgeBuilder(config, providers). + middleware1FA := middlewares.NewBridgeBuilder(*config, providers). WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone). WithPostMiddlewares(middlewares.Require1FA). Build() @@ -162,7 +162,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers) for name, endpoint := range config.Server.Endpoints.Authz { uri := path.Join(pathAuthz, name) - authz := handlers.NewAuthzBuilder().WithConfig(&config).WithEndpointConfig(endpoint).Build() + authz := handlers.NewAuthzBuilder().WithConfig(config).WithEndpointConfig(endpoint).Build() handler := middlewares.Wrap(metricsVRMW, bridge(authz.Handler)) @@ -268,7 +268,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers) } if providers.OpenIDConnect != nil { - bridgeOIDC := middlewares.NewBridgeBuilder(config, providers).WithPreMiddlewares( + bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares( middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore, ).Build() diff --git a/internal/server/server.go b/internal/server/server.go index 319c708bc..1841cfaa7 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,7 +7,6 @@ import ( "net" "os" "strconv" - "strings" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" @@ -18,9 +17,9 @@ import ( ) // CreateDefaultServer Create Authelia's internal webserver with the given configuration and providers. -func CreateDefaultServer(config schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, err error) { +func CreateDefaultServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) { if err = providers.Templates.LoadTemplatedAssets(assets); err != nil { - return nil, nil, fmt.Errorf("failed to load templated assets: %w", err) + return nil, nil, nil, false, fmt.Errorf("failed to load templated assets: %w", err) } server = &fasthttp.Server{ @@ -38,15 +37,14 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov address := net.JoinHostPort(config.Server.Host, strconv.Itoa(config.Server.Port)) var ( - connectionType string connectionScheme string ) if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" { - connectionType, connectionScheme = connTLS, schemeHTTPS + isTLS, connectionScheme = true, schemeHTTPS if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil { - return nil, nil, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err) + return nil, nil, nil, false, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err) } if len(config.Server.TLS.ClientCertificates) > 0 { @@ -56,7 +54,7 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov for _, path := range config.Server.TLS.ClientCertificates { if cert, err = os.ReadFile(path); err != nil { - return nil, nil, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err) + return nil, nil, nil, false, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err) } caCertPool.AppendCertsFromPEM(cert) @@ -69,51 +67,51 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov } if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil { - return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err) + return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err) } } else { - connectionType, connectionScheme = connNonTLS, schemeHTTP + connectionScheme = schemeHTTP if listener, err = net.Listen("tcp", address); err != nil { - return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err) + return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err) } } if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Host, config.Server.Path, config.Server.Port); err != nil { - return nil, nil, fmt.Errorf("unable to configure healthcheck: %w", err) + return nil, nil, nil, false, fmt.Errorf("unable to configure healthcheck: %w", err) } - paths := []string{"/"} + paths = []string{"/"} if config.Server.Path != "" { paths = append(paths, config.Server.Path) } - logging.Logger().Infof(fmtLogServerInit, "server", connectionType, listener.Addr().String(), strings.Join(paths, "' and '")) - - return server, listener, nil + return server, listener, paths, isTLS, nil } // CreateMetricsServer creates a metrics server. -func CreateMetricsServer(config schema.TelemetryMetricsConfig) (server *fasthttp.Server, listener net.Listener, err error) { - if listener, err = config.Address.Listener(); err != nil { - return nil, nil, err +func CreateMetricsServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) { + if providers.Metrics == nil { + return } server = &fasthttp.Server{ ErrorHandler: handleError(), NoDefaultServerHeader: true, Handler: handleMetrics(), - ReadBufferSize: config.Buffers.Read, - WriteBufferSize: config.Buffers.Write, - ReadTimeout: config.Timeouts.Read, - WriteTimeout: config.Timeouts.Write, - IdleTimeout: config.Timeouts.Idle, + ReadBufferSize: config.Telemetry.Metrics.Buffers.Read, + WriteBufferSize: config.Telemetry.Metrics.Buffers.Write, + ReadTimeout: config.Telemetry.Metrics.Timeouts.Read, + WriteTimeout: config.Telemetry.Metrics.Timeouts.Write, + IdleTimeout: config.Telemetry.Metrics.Timeouts.Idle, Logger: logging.LoggerPrintf(logrus.DebugLevel), } - logging.Logger().Infof(fmtLogServerInit, "server (metrics)", connNonTLS, listener.Addr().String(), "/metrics") + if listener, err = config.Telemetry.Metrics.Address.Listener(); err != nil { + return nil, nil, nil, false, err + } - return server, listener, nil + return server, listener, []string{"/metrics"}, false, nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 9b6dfa992..1712ae8c8 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -152,7 +152,7 @@ func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLS return nil, err } - s, listener, err := CreateDefaultServer(configuration, providers) + s, listener, _, _, err := CreateDefaultServer(&configuration, providers) if err != nil { return nil, err diff --git a/internal/suites/environment.go b/internal/suites/environment.go index 80f5ac260..0c79b898d 100644 --- a/internal/suites/environment.go +++ b/internal/suites/environment.go @@ -42,7 +42,7 @@ func waitUntilAutheliaBackendIsReady(dockerEnvironment *DockerEnvironment) error 90*time.Second, dockerEnvironment, "authelia-backend", - []string{"Initializing server for"}) + []string{"Startup Complete"}) } func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error {