refactor(commands): services (#4914)
Misc refactoring of the services logic to simplify thepull/4916/head
parent
1a5178a8a5
commit
2888ee7f41
|
@ -772,3 +772,7 @@ Layouts:
|
|||
ANSIC: Mon Jan _2 15:04:05 2006
|
||||
Date: 2006-01-02`
|
||||
)
|
||||
|
||||
const (
|
||||
fmtLogServerListening = "Server is listening for %s connections on '%s' path '%s'"
|
||||
)
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/authorization"
|
||||
|
@ -35,14 +34,8 @@ import (
|
|||
func NewCmdCtx() *CmdCtx {
|
||||
ctx := context.Background()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
group, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
return &CmdCtx{
|
||||
Context: ctx,
|
||||
cancel: cancel,
|
||||
group: group,
|
||||
log: logging.Logger(),
|
||||
providers: middlewares.Providers{
|
||||
Random: &random.Cryptographical{},
|
||||
|
@ -55,9 +48,6 @@ func NewCmdCtx() *CmdCtx {
|
|||
type CmdCtx struct {
|
||||
context.Context
|
||||
|
||||
cancel context.CancelFunc
|
||||
group *errgroup.Group
|
||||
|
||||
log *logrus.Logger
|
||||
|
||||
config *schema.Configuration
|
||||
|
|
|
@ -2,21 +2,13 @@ package commands
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/model"
|
||||
"github.com/authelia/authelia/v4/internal/server"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -95,195 +87,11 @@ func (ctx *CmdCtx) RootRunE(_ *cobra.Command, _ []string) (err error) {
|
|||
|
||||
ctx.cconfig = nil
|
||||
|
||||
runServices(ctx)
|
||||
servicesRun(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocyclo // Complexity is required in this function.
|
||||
func runServices(ctx *CmdCtx) {
|
||||
defer ctx.cancel()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
defer signal.Stop(quit)
|
||||
|
||||
var (
|
||||
mainServer, metricsServer *fasthttp.Server
|
||||
mainListener, metricsListener net.Listener
|
||||
)
|
||||
|
||||
ctx.group.Go(func() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
ctx.log.WithError(recoverErr(r)).Errorf("Server (main) critical error caught (recovered)")
|
||||
}
|
||||
}()
|
||||
|
||||
if mainServer, mainListener, err = server.CreateDefaultServer(*ctx.config, ctx.providers); err != nil {
|
||||
ctx.log.WithError(err).Error("Create Server (main) returned error")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if err = mainServer.Serve(mainListener); err != nil {
|
||||
ctx.log.WithError(err).Error("Server (main) returned error")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx.group.Go(func() (err error) {
|
||||
if ctx.providers.Metrics == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
ctx.log.WithError(recoverErr(r)).Errorf("Server (metrics) critical error caught (recovered)")
|
||||
}
|
||||
}()
|
||||
|
||||
if metricsServer, metricsListener, err = server.CreateMetricsServer(ctx.config.Telemetry.Metrics); err != nil {
|
||||
ctx.log.WithError(err).Error("Create Server (metrics) returned error")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if err = metricsServer.Serve(metricsListener); err != nil {
|
||||
ctx.log.WithError(err).Error("Server (metrics) returned error")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch {
|
||||
provider := ctx.providers.UserProvider.(*authentication.FileUserProvider)
|
||||
if watcher, err := runServiceFileWatcher(ctx, ctx.config.AuthenticationBackend.File.Path, provider); err != nil {
|
||||
ctx.log.WithError(err).Errorf("File Watcher (user database) start returned error")
|
||||
} else {
|
||||
defer func(watcher *fsnotify.Watcher) {
|
||||
if err := watcher.Close(); err != nil {
|
||||
ctx.log.WithError(err).Errorf("File Watcher (user database) close returned error")
|
||||
}
|
||||
}(watcher)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case s := <-quit:
|
||||
switch s {
|
||||
case syscall.SIGINT:
|
||||
ctx.log.Debugf("Shutdown started due to SIGINT")
|
||||
case syscall.SIGQUIT:
|
||||
ctx.log.Debugf("Shutdown started due to SIGQUIT")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
ctx.log.Debugf("Shutdown started due to context completion")
|
||||
}
|
||||
|
||||
ctx.cancel()
|
||||
|
||||
ctx.log.Infof("Shutting down")
|
||||
|
||||
var err error
|
||||
|
||||
if mainServer != nil {
|
||||
if err = mainServer.Shutdown(); err != nil {
|
||||
ctx.log.WithError(err).Errorf("Error occurred shutting down the server")
|
||||
}
|
||||
}
|
||||
|
||||
if metricsServer != nil {
|
||||
if err = metricsServer.Shutdown(); err != nil {
|
||||
ctx.log.WithError(err).Errorf("Error occurred shutting down the metrics server")
|
||||
}
|
||||
}
|
||||
|
||||
if err = ctx.providers.StorageProvider.Close(); err != nil {
|
||||
ctx.log.WithError(err).Errorf("Error occurred closing the database connection")
|
||||
}
|
||||
|
||||
if err = ctx.group.Wait(); err != nil {
|
||||
ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
type ReloadFilter func(path string) (skipped bool)
|
||||
|
||||
type ProviderReload interface {
|
||||
Reload() (reloaded bool, err error)
|
||||
}
|
||||
|
||||
func runServiceFileWatcher(ctx *CmdCtx, path string, reload ProviderReload) (watcher *fsnotify.Watcher, err error) {
|
||||
if watcher, err = fsnotify.NewWatcher(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
failed := make(chan struct{})
|
||||
|
||||
var directory, filename string
|
||||
|
||||
if path != "" {
|
||||
directory, filename = filepath.Dir(path), filepath.Base(path)
|
||||
}
|
||||
|
||||
ctx.group.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-failed:
|
||||
return nil
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if filename != filepath.Base(event.Name) {
|
||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Tracef("File modification detected to irrelevant file")
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
|
||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("File modification detected")
|
||||
|
||||
switch reloaded, err := reload.Reload(); {
|
||||
case err != nil:
|
||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).WithError(err).Error("Error occurred reloading file")
|
||||
case reloaded:
|
||||
ctx.log.WithField("file", event.Name).Info("Reloaded file successfully")
|
||||
default:
|
||||
ctx.log.WithField("file", event.Name).Debug("Reload of file was triggered but it was skipped")
|
||||
}
|
||||
case event.Op&fsnotify.Remove == fsnotify.Remove:
|
||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("Remove of file was detected")
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ctx.log.WithError(err).Errorf("Error while watching files")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err := watcher.Add(directory); err != nil {
|
||||
failed <- struct{}{}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.log.WithField("directory", directory).WithField("file", filename).Debug("Directory is being watched for changes to the file")
|
||||
|
||||
return watcher, nil
|
||||
}
|
||||
|
||||
func doStartupChecks(ctx *CmdCtx) {
|
||||
var (
|
||||
failures []string
|
||||
|
|
|
@ -0,0 +1,311 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/server"
|
||||
)
|
||||
|
||||
// NewServerService creates a new ServerService with the appropriate logger etc.
|
||||
func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) {
|
||||
return &ServerService{
|
||||
server: server,
|
||||
listener: listener,
|
||||
paths: paths,
|
||||
isTLS: isTLS,
|
||||
log: log.WithFields(map[string]any{"service": "server", "server": name}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc.
|
||||
func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("path must be specified")
|
||||
}
|
||||
|
||||
var info os.FileInfo
|
||||
|
||||
if info, err = os.Stat(path); err != nil {
|
||||
return nil, fmt.Errorf("error stating file '%s': %w", path, err)
|
||||
}
|
||||
|
||||
if path, err = filepath.Abs(path); err != nil {
|
||||
return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err)
|
||||
}
|
||||
|
||||
var watcher *fsnotify.Watcher
|
||||
|
||||
if watcher, err = fsnotify.NewWatcher(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry := log.WithFields(map[string]any{"service": "watcher", "watcher": name})
|
||||
|
||||
if info.IsDir() {
|
||||
service = &FileWatcherService{
|
||||
watcher: watcher,
|
||||
reload: reload,
|
||||
log: entry,
|
||||
directory: filepath.Clean(path),
|
||||
}
|
||||
} else {
|
||||
service = &FileWatcherService{
|
||||
watcher: watcher,
|
||||
reload: reload,
|
||||
log: entry,
|
||||
directory: filepath.Dir(path),
|
||||
file: filepath.Base(path),
|
||||
}
|
||||
}
|
||||
|
||||
if err = service.watcher.Add(service.directory); err != nil {
|
||||
return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// ProviderReload represents the required methods to support reloading a provider.
|
||||
type ProviderReload interface {
|
||||
Reload() (reloaded bool, err error)
|
||||
}
|
||||
|
||||
// Service represents the required methods to support handling a service.
|
||||
type Service interface {
|
||||
Run() (err error)
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
// ServerService is a Service which runs a webserver.
|
||||
type ServerService struct {
|
||||
server *fasthttp.Server
|
||||
paths []string
|
||||
isTLS bool
|
||||
listener net.Listener
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
// Run the ServerService.
|
||||
func (service *ServerService) Run() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
|
||||
}
|
||||
}()
|
||||
|
||||
service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '"))
|
||||
|
||||
if err = service.server.Serve(service.listener); err != nil {
|
||||
service.log.WithError(err).Error("Error returned attempting to serve requests")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown the ServerService.
|
||||
func (service *ServerService) Shutdown() {
|
||||
if err := service.server.Shutdown(); err != nil {
|
||||
service.log.WithError(err).Error("Error occurred during shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
// FileWatcherService is a Service that watches files for changes.
|
||||
type FileWatcherService struct {
|
||||
watcher *fsnotify.Watcher
|
||||
reload ProviderReload
|
||||
|
||||
log *logrus.Entry
|
||||
file string
|
||||
directory string
|
||||
}
|
||||
|
||||
// Run the FileWatcherService.
|
||||
func (service *FileWatcherService) Run() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
|
||||
}
|
||||
}()
|
||||
|
||||
service.log.WithField("file", filepath.Join(service.directory, service.file)).Info("Watching for file changes to the file")
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-service.watcher.Events:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if service.file != "" && service.file != filepath.Base(event.Name) {
|
||||
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Tracef("File modification detected to irrelevant file")
|
||||
break
|
||||
}
|
||||
|
||||
switch {
|
||||
case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
|
||||
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File modification was detected")
|
||||
|
||||
var reloaded bool
|
||||
|
||||
switch reloaded, err = service.reload.Reload(); {
|
||||
case err != nil:
|
||||
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).WithError(err).Error("Error occurred during reload")
|
||||
case reloaded:
|
||||
service.log.WithField("file", event.Name).Info("Reloaded successfully")
|
||||
default:
|
||||
service.log.WithField("file", event.Name).Debug("Reload of was triggered but it was skipped")
|
||||
}
|
||||
case event.Op&fsnotify.Remove == fsnotify.Remove:
|
||||
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File remove was detected")
|
||||
}
|
||||
case err, ok := <-service.watcher.Errors:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
service.log.WithError(err).Errorf("Error while watching files")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the FileWatcherService.
|
||||
func (service *FileWatcherService) Shutdown() {
|
||||
if err := service.watcher.Close(); err != nil {
|
||||
service.log.WithError(err).Error("Error occurred during shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func svcSvrMainFunc(ctx *CmdCtx) (service Service) {
|
||||
switch svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.config, ctx.providers); {
|
||||
case err != nil:
|
||||
ctx.log.WithError(err).Fatal("Create Server Service (main) returned error")
|
||||
case svr != nil && listener != nil:
|
||||
service = NewServerService("main", svr, listener, paths, isTLS, ctx.log)
|
||||
default:
|
||||
ctx.log.Fatal("Create Server Service (main) failed")
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func svcSvrMetricsFunc(ctx *CmdCtx) (service Service) {
|
||||
switch svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.config, ctx.providers); {
|
||||
case err != nil:
|
||||
ctx.log.WithError(err).Fatal("Create Server Service (metrics) returned error")
|
||||
case svr != nil && listener != nil:
|
||||
service = NewServerService("metrics", svr, listener, paths, isTLS, ctx.log)
|
||||
default:
|
||||
ctx.log.Debug("Create Server Service (metrics) skipped")
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func svcWatcherUsersFunc(ctx *CmdCtx) (service Service) {
|
||||
var err error
|
||||
|
||||
if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch {
|
||||
provider := ctx.providers.UserProvider.(*authentication.FileUserProvider)
|
||||
|
||||
if service, err = NewFileWatcherService("users", ctx.config.AuthenticationBackend.File.Path, provider, ctx.log); err != nil {
|
||||
ctx.log.WithError(err).Fatal("Create Watcher Service (users) returned error")
|
||||
}
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func connectionType(isTLS bool) string {
|
||||
if isTLS {
|
||||
return "TLS"
|
||||
}
|
||||
|
||||
return "non-TLS"
|
||||
}
|
||||
|
||||
func servicesRun(ctx *CmdCtx) {
|
||||
cctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
group, cctx := errgroup.WithContext(cctx)
|
||||
|
||||
defer cancel()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
defer signal.Stop(quit)
|
||||
|
||||
var (
|
||||
services []Service
|
||||
)
|
||||
|
||||
for _, serviceFunc := range []func(ctx *CmdCtx) Service{
|
||||
svcSvrMainFunc, svcSvrMetricsFunc,
|
||||
svcWatcherUsersFunc,
|
||||
} {
|
||||
if service := serviceFunc(ctx); service != nil {
|
||||
services = append(services, service)
|
||||
|
||||
group.Go(service.Run)
|
||||
}
|
||||
}
|
||||
|
||||
ctx.log.Info("Startup Complete")
|
||||
|
||||
select {
|
||||
case s := <-quit:
|
||||
switch s {
|
||||
case syscall.SIGINT:
|
||||
ctx.log.WithField("signal", "SIGINT").Debugf("Shutdown started due to signal")
|
||||
case syscall.SIGTERM:
|
||||
ctx.log.WithField("signal", "SIGTERM").Debugf("Shutdown started due to signal")
|
||||
}
|
||||
case <-cctx.Done():
|
||||
ctx.log.Debugf("Shutdown started due to context completion")
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
ctx.log.Infof("Shutting down")
|
||||
|
||||
wgShutdown := &sync.WaitGroup{}
|
||||
|
||||
for _, service := range services {
|
||||
go func() {
|
||||
service.Shutdown()
|
||||
|
||||
wgShutdown.Done()
|
||||
}()
|
||||
|
||||
wgShutdown.Add(1)
|
||||
}
|
||||
|
||||
wgShutdown.Wait()
|
||||
|
||||
var err error
|
||||
|
||||
if err = ctx.providers.StorageProvider.Close(); err != nil {
|
||||
ctx.log.WithError(err).Error("Error occurred closing database connections")
|
||||
}
|
||||
|
||||
if err = group.Wait(); err != nil {
|
||||
ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
|
||||
}
|
||||
}
|
|
@ -83,12 +83,3 @@ const (
|
|||
tmplCSPSwaggerNonce = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline' 'nonce-%s'; style-src 'self' 'nonce-%s'; base-uri 'self'"
|
||||
tmplCSPSwagger = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline'; style-src 'self'; base-uri 'self'"
|
||||
)
|
||||
|
||||
const (
|
||||
connNonTLS = "non-TLS"
|
||||
connTLS = "TLS"
|
||||
)
|
||||
|
||||
const (
|
||||
fmtLogServerInit = "Initializing %s for %s connections on '%s' path '%s'"
|
||||
)
|
||||
|
|
|
@ -92,10 +92,10 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func handleRouter(config schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||
func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||
log := logging.Logger()
|
||||
|
||||
optsTemplatedFile := NewTemplatedFileOptions(&config)
|
||||
optsTemplatedFile := NewTemplatedFileOptions(config)
|
||||
|
||||
serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile)
|
||||
serveOpenAPIHandler := ServeTemplatedOpenAPI(providers.Templates.GetAssetOpenAPIIndexTemplate(), optsTemplatedFile)
|
||||
|
@ -104,7 +104,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
|||
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
||||
handlerLocales := newLocalesEmbeddedHandler()
|
||||
|
||||
bridge := middlewares.NewBridgeBuilder(config, providers).
|
||||
bridge := middlewares.NewBridgeBuilder(*config, providers).
|
||||
WithPreMiddlewares(middlewares.SecurityHeaders).Build()
|
||||
|
||||
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
|
||||
|
@ -141,11 +141,11 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
|||
r.GET("/api/"+file, handlerPublicHTML)
|
||||
}
|
||||
|
||||
middlewareAPI := middlewares.NewBridgeBuilder(config, providers).
|
||||
middlewareAPI := middlewares.NewBridgeBuilder(*config, providers).
|
||||
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
||||
Build()
|
||||
|
||||
middleware1FA := middlewares.NewBridgeBuilder(config, providers).
|
||||
middleware1FA := middlewares.NewBridgeBuilder(*config, providers).
|
||||
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
||||
WithPostMiddlewares(middlewares.Require1FA).
|
||||
Build()
|
||||
|
@ -162,7 +162,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
|||
for name, endpoint := range config.Server.Endpoints.Authz {
|
||||
uri := path.Join(pathAuthz, name)
|
||||
|
||||
authz := handlers.NewAuthzBuilder().WithConfig(&config).WithEndpointConfig(endpoint).Build()
|
||||
authz := handlers.NewAuthzBuilder().WithConfig(config).WithEndpointConfig(endpoint).Build()
|
||||
|
||||
handler := middlewares.Wrap(metricsVRMW, bridge(authz.Handler))
|
||||
|
||||
|
@ -268,7 +268,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
|||
}
|
||||
|
||||
if providers.OpenIDConnect != nil {
|
||||
bridgeOIDC := middlewares.NewBridgeBuilder(config, providers).WithPreMiddlewares(
|
||||
bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares(
|
||||
middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore,
|
||||
).Build()
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
@ -18,9 +17,9 @@ import (
|
|||
)
|
||||
|
||||
// CreateDefaultServer Create Authelia's internal webserver with the given configuration and providers.
|
||||
func CreateDefaultServer(config schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, err error) {
|
||||
func CreateDefaultServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) {
|
||||
if err = providers.Templates.LoadTemplatedAssets(assets); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to load templated assets: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("failed to load templated assets: %w", err)
|
||||
}
|
||||
|
||||
server = &fasthttp.Server{
|
||||
|
@ -38,15 +37,14 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
|
|||
address := net.JoinHostPort(config.Server.Host, strconv.Itoa(config.Server.Port))
|
||||
|
||||
var (
|
||||
connectionType string
|
||||
connectionScheme string
|
||||
)
|
||||
|
||||
if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" {
|
||||
connectionType, connectionScheme = connTLS, schemeHTTPS
|
||||
isTLS, connectionScheme = true, schemeHTTPS
|
||||
|
||||
if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
|
||||
return nil, nil, nil, false, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
|
||||
}
|
||||
|
||||
if len(config.Server.TLS.ClientCertificates) > 0 {
|
||||
|
@ -56,7 +54,7 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
|
|||
|
||||
for _, path := range config.Server.TLS.ClientCertificates {
|
||||
if cert, err = os.ReadFile(path); err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err)
|
||||
return nil, nil, nil, false, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err)
|
||||
}
|
||||
|
||||
caCertPool.AppendCertsFromPEM(cert)
|
||||
|
@ -69,51 +67,51 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
|
|||
}
|
||||
|
||||
if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
||||
}
|
||||
} else {
|
||||
connectionType, connectionScheme = connNonTLS, schemeHTTP
|
||||
connectionScheme = schemeHTTP
|
||||
|
||||
if listener, err = net.Listen("tcp", address); err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Host,
|
||||
config.Server.Path, config.Server.Port); err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to configure healthcheck: %w", err)
|
||||
return nil, nil, nil, false, fmt.Errorf("unable to configure healthcheck: %w", err)
|
||||
}
|
||||
|
||||
paths := []string{"/"}
|
||||
paths = []string{"/"}
|
||||
|
||||
if config.Server.Path != "" {
|
||||
paths = append(paths, config.Server.Path)
|
||||
}
|
||||
|
||||
logging.Logger().Infof(fmtLogServerInit, "server", connectionType, listener.Addr().String(), strings.Join(paths, "' and '"))
|
||||
|
||||
return server, listener, nil
|
||||
return server, listener, paths, isTLS, nil
|
||||
}
|
||||
|
||||
// CreateMetricsServer creates a metrics server.
|
||||
func CreateMetricsServer(config schema.TelemetryMetricsConfig) (server *fasthttp.Server, listener net.Listener, err error) {
|
||||
if listener, err = config.Address.Listener(); err != nil {
|
||||
return nil, nil, err
|
||||
func CreateMetricsServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) {
|
||||
if providers.Metrics == nil {
|
||||
return
|
||||
}
|
||||
|
||||
server = &fasthttp.Server{
|
||||
ErrorHandler: handleError(),
|
||||
NoDefaultServerHeader: true,
|
||||
Handler: handleMetrics(),
|
||||
ReadBufferSize: config.Buffers.Read,
|
||||
WriteBufferSize: config.Buffers.Write,
|
||||
ReadTimeout: config.Timeouts.Read,
|
||||
WriteTimeout: config.Timeouts.Write,
|
||||
IdleTimeout: config.Timeouts.Idle,
|
||||
ReadBufferSize: config.Telemetry.Metrics.Buffers.Read,
|
||||
WriteBufferSize: config.Telemetry.Metrics.Buffers.Write,
|
||||
ReadTimeout: config.Telemetry.Metrics.Timeouts.Read,
|
||||
WriteTimeout: config.Telemetry.Metrics.Timeouts.Write,
|
||||
IdleTimeout: config.Telemetry.Metrics.Timeouts.Idle,
|
||||
Logger: logging.LoggerPrintf(logrus.DebugLevel),
|
||||
}
|
||||
|
||||
logging.Logger().Infof(fmtLogServerInit, "server (metrics)", connNonTLS, listener.Addr().String(), "/metrics")
|
||||
if listener, err = config.Telemetry.Metrics.Address.Listener(); err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
return server, listener, nil
|
||||
return server, listener, []string{"/metrics"}, false, nil
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLS
|
|||
return nil, err
|
||||
}
|
||||
|
||||
s, listener, err := CreateDefaultServer(configuration, providers)
|
||||
s, listener, _, _, err := CreateDefaultServer(&configuration, providers)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -42,7 +42,7 @@ func waitUntilAutheliaBackendIsReady(dockerEnvironment *DockerEnvironment) error
|
|||
90*time.Second,
|
||||
dockerEnvironment,
|
||||
"authelia-backend",
|
||||
[]string{"Initializing server for"})
|
||||
[]string{"Startup Complete"})
|
||||
}
|
||||
|
||||
func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error {
|
||||
|
|
Loading…
Reference in New Issue