refactor(commands): services (#4914)
Misc refactoring of the services logic to simplify thepull/4916/head
parent
1a5178a8a5
commit
2888ee7f41
|
@ -772,3 +772,7 @@ Layouts:
|
||||||
ANSIC: Mon Jan _2 15:04:05 2006
|
ANSIC: Mon Jan _2 15:04:05 2006
|
||||||
Date: 2006-01-02`
|
Date: 2006-01-02`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fmtLogServerListening = "Server is listening for %s connections on '%s' path '%s'"
|
||||||
|
)
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/pflag"
|
"github.com/spf13/pflag"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/authentication"
|
"github.com/authelia/authelia/v4/internal/authentication"
|
||||||
"github.com/authelia/authelia/v4/internal/authorization"
|
"github.com/authelia/authelia/v4/internal/authorization"
|
||||||
|
@ -35,14 +34,8 @@ import (
|
||||||
func NewCmdCtx() *CmdCtx {
|
func NewCmdCtx() *CmdCtx {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
group, ctx := errgroup.WithContext(ctx)
|
|
||||||
|
|
||||||
return &CmdCtx{
|
return &CmdCtx{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
cancel: cancel,
|
|
||||||
group: group,
|
|
||||||
log: logging.Logger(),
|
log: logging.Logger(),
|
||||||
providers: middlewares.Providers{
|
providers: middlewares.Providers{
|
||||||
Random: &random.Cryptographical{},
|
Random: &random.Cryptographical{},
|
||||||
|
@ -55,9 +48,6 @@ func NewCmdCtx() *CmdCtx {
|
||||||
type CmdCtx struct {
|
type CmdCtx struct {
|
||||||
context.Context
|
context.Context
|
||||||
|
|
||||||
cancel context.CancelFunc
|
|
||||||
group *errgroup.Group
|
|
||||||
|
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
|
|
||||||
config *schema.Configuration
|
config *schema.Configuration
|
||||||
|
|
|
@ -2,21 +2,13 @@ package commands
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/authentication"
|
|
||||||
"github.com/authelia/authelia/v4/internal/logging"
|
"github.com/authelia/authelia/v4/internal/logging"
|
||||||
"github.com/authelia/authelia/v4/internal/model"
|
"github.com/authelia/authelia/v4/internal/model"
|
||||||
"github.com/authelia/authelia/v4/internal/server"
|
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -95,195 +87,11 @@ func (ctx *CmdCtx) RootRunE(_ *cobra.Command, _ []string) (err error) {
|
||||||
|
|
||||||
ctx.cconfig = nil
|
ctx.cconfig = nil
|
||||||
|
|
||||||
runServices(ctx)
|
servicesRun(ctx)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocyclo // Complexity is required in this function.
|
|
||||||
func runServices(ctx *CmdCtx) {
|
|
||||||
defer ctx.cancel()
|
|
||||||
|
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
|
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
defer signal.Stop(quit)
|
|
||||||
|
|
||||||
var (
|
|
||||||
mainServer, metricsServer *fasthttp.Server
|
|
||||||
mainListener, metricsListener net.Listener
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.group.Go(func() (err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
ctx.log.WithError(recoverErr(r)).Errorf("Server (main) critical error caught (recovered)")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if mainServer, mainListener, err = server.CreateDefaultServer(*ctx.config, ctx.providers); err != nil {
|
|
||||||
ctx.log.WithError(err).Error("Create Server (main) returned error")
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = mainServer.Serve(mainListener); err != nil {
|
|
||||||
ctx.log.WithError(err).Error("Server (main) returned error")
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx.group.Go(func() (err error) {
|
|
||||||
if ctx.providers.Metrics == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
ctx.log.WithError(recoverErr(r)).Errorf("Server (metrics) critical error caught (recovered)")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if metricsServer, metricsListener, err = server.CreateMetricsServer(ctx.config.Telemetry.Metrics); err != nil {
|
|
||||||
ctx.log.WithError(err).Error("Create Server (metrics) returned error")
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = metricsServer.Serve(metricsListener); err != nil {
|
|
||||||
ctx.log.WithError(err).Error("Server (metrics) returned error")
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch {
|
|
||||||
provider := ctx.providers.UserProvider.(*authentication.FileUserProvider)
|
|
||||||
if watcher, err := runServiceFileWatcher(ctx, ctx.config.AuthenticationBackend.File.Path, provider); err != nil {
|
|
||||||
ctx.log.WithError(err).Errorf("File Watcher (user database) start returned error")
|
|
||||||
} else {
|
|
||||||
defer func(watcher *fsnotify.Watcher) {
|
|
||||||
if err := watcher.Close(); err != nil {
|
|
||||||
ctx.log.WithError(err).Errorf("File Watcher (user database) close returned error")
|
|
||||||
}
|
|
||||||
}(watcher)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case s := <-quit:
|
|
||||||
switch s {
|
|
||||||
case syscall.SIGINT:
|
|
||||||
ctx.log.Debugf("Shutdown started due to SIGINT")
|
|
||||||
case syscall.SIGQUIT:
|
|
||||||
ctx.log.Debugf("Shutdown started due to SIGQUIT")
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
ctx.log.Debugf("Shutdown started due to context completion")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.cancel()
|
|
||||||
|
|
||||||
ctx.log.Infof("Shutting down")
|
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if mainServer != nil {
|
|
||||||
if err = mainServer.Shutdown(); err != nil {
|
|
||||||
ctx.log.WithError(err).Errorf("Error occurred shutting down the server")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if metricsServer != nil {
|
|
||||||
if err = metricsServer.Shutdown(); err != nil {
|
|
||||||
ctx.log.WithError(err).Errorf("Error occurred shutting down the metrics server")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = ctx.providers.StorageProvider.Close(); err != nil {
|
|
||||||
ctx.log.WithError(err).Errorf("Error occurred closing the database connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = ctx.group.Wait(); err != nil {
|
|
||||||
ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ReloadFilter func(path string) (skipped bool)
|
|
||||||
|
|
||||||
type ProviderReload interface {
|
|
||||||
Reload() (reloaded bool, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runServiceFileWatcher(ctx *CmdCtx, path string, reload ProviderReload) (watcher *fsnotify.Watcher, err error) {
|
|
||||||
if watcher, err = fsnotify.NewWatcher(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
failed := make(chan struct{})
|
|
||||||
|
|
||||||
var directory, filename string
|
|
||||||
|
|
||||||
if path != "" {
|
|
||||||
directory, filename = filepath.Dir(path), filepath.Base(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.group.Go(func() error {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-failed:
|
|
||||||
return nil
|
|
||||||
case event, ok := <-watcher.Events:
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if filename != filepath.Base(event.Name) {
|
|
||||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Tracef("File modification detected to irrelevant file")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
|
|
||||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("File modification detected")
|
|
||||||
|
|
||||||
switch reloaded, err := reload.Reload(); {
|
|
||||||
case err != nil:
|
|
||||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).WithError(err).Error("Error occurred reloading file")
|
|
||||||
case reloaded:
|
|
||||||
ctx.log.WithField("file", event.Name).Info("Reloaded file successfully")
|
|
||||||
default:
|
|
||||||
ctx.log.WithField("file", event.Name).Debug("Reload of file was triggered but it was skipped")
|
|
||||||
}
|
|
||||||
case event.Op&fsnotify.Remove == fsnotify.Remove:
|
|
||||||
ctx.log.WithField("file", event.Name).WithField("op", event.Op).Debug("Remove of file was detected")
|
|
||||||
}
|
|
||||||
case err, ok := <-watcher.Errors:
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
ctx.log.WithError(err).Errorf("Error while watching files")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := watcher.Add(directory); err != nil {
|
|
||||||
failed <- struct{}{}
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.log.WithField("directory", directory).WithField("file", filename).Debug("Directory is being watched for changes to the file")
|
|
||||||
|
|
||||||
return watcher, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func doStartupChecks(ctx *CmdCtx) {
|
func doStartupChecks(ctx *CmdCtx) {
|
||||||
var (
|
var (
|
||||||
failures []string
|
failures []string
|
||||||
|
|
|
@ -0,0 +1,311 @@
|
||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/authentication"
|
||||||
|
"github.com/authelia/authelia/v4/internal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewServerService creates a new ServerService with the appropriate logger etc.
|
||||||
|
func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) {
|
||||||
|
return &ServerService{
|
||||||
|
server: server,
|
||||||
|
listener: listener,
|
||||||
|
paths: paths,
|
||||||
|
isTLS: isTLS,
|
||||||
|
log: log.WithFields(map[string]any{"service": "server", "server": name}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc.
|
||||||
|
func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) {
|
||||||
|
if path == "" {
|
||||||
|
return nil, fmt.Errorf("path must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
var info os.FileInfo
|
||||||
|
|
||||||
|
if info, err = os.Stat(path); err != nil {
|
||||||
|
return nil, fmt.Errorf("error stating file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if path, err = filepath.Abs(path); err != nil {
|
||||||
|
return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var watcher *fsnotify.Watcher
|
||||||
|
|
||||||
|
if watcher, err = fsnotify.NewWatcher(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := log.WithFields(map[string]any{"service": "watcher", "watcher": name})
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
service = &FileWatcherService{
|
||||||
|
watcher: watcher,
|
||||||
|
reload: reload,
|
||||||
|
log: entry,
|
||||||
|
directory: filepath.Clean(path),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
service = &FileWatcherService{
|
||||||
|
watcher: watcher,
|
||||||
|
reload: reload,
|
||||||
|
log: entry,
|
||||||
|
directory: filepath.Dir(path),
|
||||||
|
file: filepath.Base(path),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = service.watcher.Add(service.directory); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderReload represents the required methods to support reloading a provider.
|
||||||
|
type ProviderReload interface {
|
||||||
|
Reload() (reloaded bool, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Service represents the required methods to support handling a service.
|
||||||
|
type Service interface {
|
||||||
|
Run() (err error)
|
||||||
|
Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerService is a Service which runs a webserver.
|
||||||
|
type ServerService struct {
|
||||||
|
server *fasthttp.Server
|
||||||
|
paths []string
|
||||||
|
isTLS bool
|
||||||
|
listener net.Listener
|
||||||
|
log *logrus.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the ServerService.
|
||||||
|
func (service *ServerService) Run() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
service.log.Infof(fmtLogServerListening, connectionType(service.isTLS), service.listener.Addr().String(), strings.Join(service.paths, "' and '"))
|
||||||
|
|
||||||
|
if err = service.server.Serve(service.listener); err != nil {
|
||||||
|
service.log.WithError(err).Error("Error returned attempting to serve requests")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the ServerService.
|
||||||
|
func (service *ServerService) Shutdown() {
|
||||||
|
if err := service.server.Shutdown(); err != nil {
|
||||||
|
service.log.WithError(err).Error("Error occurred during shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileWatcherService is a Service that watches files for changes.
|
||||||
|
type FileWatcherService struct {
|
||||||
|
watcher *fsnotify.Watcher
|
||||||
|
reload ProviderReload
|
||||||
|
|
||||||
|
log *logrus.Entry
|
||||||
|
file string
|
||||||
|
directory string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the FileWatcherService.
|
||||||
|
func (service *FileWatcherService) Run() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
service.log.WithField("file", filepath.Join(service.directory, service.file)).Info("Watching for file changes to the file")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-service.watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.file != "" && service.file != filepath.Base(event.Name) {
|
||||||
|
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Tracef("File modification detected to irrelevant file")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case event.Op&fsnotify.Write == fsnotify.Write, event.Op&fsnotify.Create == fsnotify.Create:
|
||||||
|
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File modification was detected")
|
||||||
|
|
||||||
|
var reloaded bool
|
||||||
|
|
||||||
|
switch reloaded, err = service.reload.Reload(); {
|
||||||
|
case err != nil:
|
||||||
|
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).WithError(err).Error("Error occurred during reload")
|
||||||
|
case reloaded:
|
||||||
|
service.log.WithField("file", event.Name).Info("Reloaded successfully")
|
||||||
|
default:
|
||||||
|
service.log.WithField("file", event.Name).Debug("Reload of was triggered but it was skipped")
|
||||||
|
}
|
||||||
|
case event.Op&fsnotify.Remove == fsnotify.Remove:
|
||||||
|
service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op}).Debug("File remove was detected")
|
||||||
|
}
|
||||||
|
case err, ok := <-service.watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
service.log.WithError(err).Errorf("Error while watching files")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown the FileWatcherService.
|
||||||
|
func (service *FileWatcherService) Shutdown() {
|
||||||
|
if err := service.watcher.Close(); err != nil {
|
||||||
|
service.log.WithError(err).Error("Error occurred during shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func svcSvrMainFunc(ctx *CmdCtx) (service Service) {
|
||||||
|
switch svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.config, ctx.providers); {
|
||||||
|
case err != nil:
|
||||||
|
ctx.log.WithError(err).Fatal("Create Server Service (main) returned error")
|
||||||
|
case svr != nil && listener != nil:
|
||||||
|
service = NewServerService("main", svr, listener, paths, isTLS, ctx.log)
|
||||||
|
default:
|
||||||
|
ctx.log.Fatal("Create Server Service (main) failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
func svcSvrMetricsFunc(ctx *CmdCtx) (service Service) {
|
||||||
|
switch svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.config, ctx.providers); {
|
||||||
|
case err != nil:
|
||||||
|
ctx.log.WithError(err).Fatal("Create Server Service (metrics) returned error")
|
||||||
|
case svr != nil && listener != nil:
|
||||||
|
service = NewServerService("metrics", svr, listener, paths, isTLS, ctx.log)
|
||||||
|
default:
|
||||||
|
ctx.log.Debug("Create Server Service (metrics) skipped")
|
||||||
|
}
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
func svcWatcherUsersFunc(ctx *CmdCtx) (service Service) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if ctx.config.AuthenticationBackend.File != nil && ctx.config.AuthenticationBackend.File.Watch {
|
||||||
|
provider := ctx.providers.UserProvider.(*authentication.FileUserProvider)
|
||||||
|
|
||||||
|
if service, err = NewFileWatcherService("users", ctx.config.AuthenticationBackend.File.Path, provider, ctx.log); err != nil {
|
||||||
|
ctx.log.WithError(err).Fatal("Create Watcher Service (users) returned error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return service
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionType(isTLS bool) string {
|
||||||
|
if isTLS {
|
||||||
|
return "TLS"
|
||||||
|
}
|
||||||
|
|
||||||
|
return "non-TLS"
|
||||||
|
}
|
||||||
|
|
||||||
|
func servicesRun(ctx *CmdCtx) {
|
||||||
|
cctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
group, cctx := errgroup.WithContext(cctx)
|
||||||
|
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
defer signal.Stop(quit)
|
||||||
|
|
||||||
|
var (
|
||||||
|
services []Service
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, serviceFunc := range []func(ctx *CmdCtx) Service{
|
||||||
|
svcSvrMainFunc, svcSvrMetricsFunc,
|
||||||
|
svcWatcherUsersFunc,
|
||||||
|
} {
|
||||||
|
if service := serviceFunc(ctx); service != nil {
|
||||||
|
services = append(services, service)
|
||||||
|
|
||||||
|
group.Go(service.Run)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.log.Info("Startup Complete")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s := <-quit:
|
||||||
|
switch s {
|
||||||
|
case syscall.SIGINT:
|
||||||
|
ctx.log.WithField("signal", "SIGINT").Debugf("Shutdown started due to signal")
|
||||||
|
case syscall.SIGTERM:
|
||||||
|
ctx.log.WithField("signal", "SIGTERM").Debugf("Shutdown started due to signal")
|
||||||
|
}
|
||||||
|
case <-cctx.Done():
|
||||||
|
ctx.log.Debugf("Shutdown started due to context completion")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
ctx.log.Infof("Shutting down")
|
||||||
|
|
||||||
|
wgShutdown := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
for _, service := range services {
|
||||||
|
go func() {
|
||||||
|
service.Shutdown()
|
||||||
|
|
||||||
|
wgShutdown.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgShutdown.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgShutdown.Wait()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if err = ctx.providers.StorageProvider.Close(); err != nil {
|
||||||
|
ctx.log.WithError(err).Error("Error occurred closing database connections")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = group.Wait(); err != nil {
|
||||||
|
ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
|
||||||
|
}
|
||||||
|
}
|
|
@ -83,12 +83,3 @@ const (
|
||||||
tmplCSPSwaggerNonce = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline' 'nonce-%s'; style-src 'self' 'nonce-%s'; base-uri 'self'"
|
tmplCSPSwaggerNonce = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline' 'nonce-%s'; style-src 'self' 'nonce-%s'; base-uri 'self'"
|
||||||
tmplCSPSwagger = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline'; style-src 'self'; base-uri 'self'"
|
tmplCSPSwagger = "default-src 'self'; img-src 'self' https://validator.swagger.io data:; object-src 'none'; script-src 'self' 'unsafe-inline'; style-src 'self'; base-uri 'self'"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
connNonTLS = "non-TLS"
|
|
||||||
connTLS = "TLS"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
fmtLogServerInit = "Initializing %s for %s connections on '%s' path '%s'"
|
|
||||||
)
|
|
||||||
|
|
|
@ -92,10 +92,10 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocyclo
|
//nolint:gocyclo
|
||||||
func handleRouter(config schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
func handleRouter(config *schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||||
log := logging.Logger()
|
log := logging.Logger()
|
||||||
|
|
||||||
optsTemplatedFile := NewTemplatedFileOptions(&config)
|
optsTemplatedFile := NewTemplatedFileOptions(config)
|
||||||
|
|
||||||
serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile)
|
serveIndexHandler := ServeTemplatedFile(providers.Templates.GetAssetIndexTemplate(), optsTemplatedFile)
|
||||||
serveOpenAPIHandler := ServeTemplatedOpenAPI(providers.Templates.GetAssetOpenAPIIndexTemplate(), optsTemplatedFile)
|
serveOpenAPIHandler := ServeTemplatedOpenAPI(providers.Templates.GetAssetOpenAPIIndexTemplate(), optsTemplatedFile)
|
||||||
|
@ -104,7 +104,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
||||||
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
||||||
handlerLocales := newLocalesEmbeddedHandler()
|
handlerLocales := newLocalesEmbeddedHandler()
|
||||||
|
|
||||||
bridge := middlewares.NewBridgeBuilder(config, providers).
|
bridge := middlewares.NewBridgeBuilder(*config, providers).
|
||||||
WithPreMiddlewares(middlewares.SecurityHeaders).Build()
|
WithPreMiddlewares(middlewares.SecurityHeaders).Build()
|
||||||
|
|
||||||
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
|
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
|
||||||
|
@ -141,11 +141,11 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
||||||
r.GET("/api/"+file, handlerPublicHTML)
|
r.GET("/api/"+file, handlerPublicHTML)
|
||||||
}
|
}
|
||||||
|
|
||||||
middlewareAPI := middlewares.NewBridgeBuilder(config, providers).
|
middlewareAPI := middlewares.NewBridgeBuilder(*config, providers).
|
||||||
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
||||||
Build()
|
Build()
|
||||||
|
|
||||||
middleware1FA := middlewares.NewBridgeBuilder(config, providers).
|
middleware1FA := middlewares.NewBridgeBuilder(*config, providers).
|
||||||
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
WithPreMiddlewares(middlewares.SecurityHeaders, middlewares.SecurityHeadersNoStore, middlewares.SecurityHeadersCSPNone).
|
||||||
WithPostMiddlewares(middlewares.Require1FA).
|
WithPostMiddlewares(middlewares.Require1FA).
|
||||||
Build()
|
Build()
|
||||||
|
@ -162,7 +162,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
||||||
for name, endpoint := range config.Server.Endpoints.Authz {
|
for name, endpoint := range config.Server.Endpoints.Authz {
|
||||||
uri := path.Join(pathAuthz, name)
|
uri := path.Join(pathAuthz, name)
|
||||||
|
|
||||||
authz := handlers.NewAuthzBuilder().WithConfig(&config).WithEndpointConfig(endpoint).Build()
|
authz := handlers.NewAuthzBuilder().WithConfig(config).WithEndpointConfig(endpoint).Build()
|
||||||
|
|
||||||
handler := middlewares.Wrap(metricsVRMW, bridge(authz.Handler))
|
handler := middlewares.Wrap(metricsVRMW, bridge(authz.Handler))
|
||||||
|
|
||||||
|
@ -268,7 +268,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
if providers.OpenIDConnect != nil {
|
if providers.OpenIDConnect != nil {
|
||||||
bridgeOIDC := middlewares.NewBridgeBuilder(config, providers).WithPreMiddlewares(
|
bridgeOIDC := middlewares.NewBridgeBuilder(*config, providers).WithPreMiddlewares(
|
||||||
middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore,
|
middlewares.SecurityHeaders, middlewares.SecurityHeadersCSPNoneOpenIDConnect, middlewares.SecurityHeadersNoStore,
|
||||||
).Build()
|
).Build()
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
@ -18,9 +17,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateDefaultServer Create Authelia's internal webserver with the given configuration and providers.
|
// CreateDefaultServer Create Authelia's internal webserver with the given configuration and providers.
|
||||||
func CreateDefaultServer(config schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, err error) {
|
func CreateDefaultServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, err error) {
|
||||||
if err = providers.Templates.LoadTemplatedAssets(assets); err != nil {
|
if err = providers.Templates.LoadTemplatedAssets(assets); err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to load templated assets: %w", err)
|
return nil, nil, nil, false, fmt.Errorf("failed to load templated assets: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
server = &fasthttp.Server{
|
server = &fasthttp.Server{
|
||||||
|
@ -38,15 +37,14 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
|
||||||
address := net.JoinHostPort(config.Server.Host, strconv.Itoa(config.Server.Port))
|
address := net.JoinHostPort(config.Server.Host, strconv.Itoa(config.Server.Port))
|
||||||
|
|
||||||
var (
|
var (
|
||||||
connectionType string
|
|
||||||
connectionScheme string
|
connectionScheme string
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" {
|
if config.Server.TLS.Certificate != "" && config.Server.TLS.Key != "" {
|
||||||
connectionType, connectionScheme = connTLS, schemeHTTPS
|
isTLS, connectionScheme = true, schemeHTTPS
|
||||||
|
|
||||||
if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil {
|
if err = server.AppendCert(config.Server.TLS.Certificate, config.Server.TLS.Key); err != nil {
|
||||||
return nil, nil, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
|
return nil, nil, nil, false, fmt.Errorf("unable to load tls server certificate '%s' or private key '%s': %w", config.Server.TLS.Certificate, config.Server.TLS.Key, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Server.TLS.ClientCertificates) > 0 {
|
if len(config.Server.TLS.ClientCertificates) > 0 {
|
||||||
|
@ -56,7 +54,7 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
|
||||||
|
|
||||||
for _, path := range config.Server.TLS.ClientCertificates {
|
for _, path := range config.Server.TLS.ClientCertificates {
|
||||||
if cert, err = os.ReadFile(path); err != nil {
|
if cert, err = os.ReadFile(path); err != nil {
|
||||||
return nil, nil, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err)
|
return nil, nil, nil, false, fmt.Errorf("unable to load tls client certificate '%s': %w", path, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
caCertPool.AppendCertsFromPEM(cert)
|
caCertPool.AppendCertsFromPEM(cert)
|
||||||
|
@ -69,51 +67,51 @@ func CreateDefaultServer(config schema.Configuration, providers middlewares.Prov
|
||||||
}
|
}
|
||||||
|
|
||||||
if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
|
if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
|
||||||
return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
connectionType, connectionScheme = connNonTLS, schemeHTTP
|
connectionScheme = schemeHTTP
|
||||||
|
|
||||||
if listener, err = net.Listen("tcp", address); err != nil {
|
if listener, err = net.Listen("tcp", address); err != nil {
|
||||||
return nil, nil, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
return nil, nil, nil, false, fmt.Errorf("unable to initialize tcp listener: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Host,
|
if err = writeHealthCheckEnv(config.Server.DisableHealthcheck, connectionScheme, config.Server.Host,
|
||||||
config.Server.Path, config.Server.Port); err != nil {
|
config.Server.Path, config.Server.Port); err != nil {
|
||||||
return nil, nil, fmt.Errorf("unable to configure healthcheck: %w", err)
|
return nil, nil, nil, false, fmt.Errorf("unable to configure healthcheck: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paths := []string{"/"}
|
paths = []string{"/"}
|
||||||
|
|
||||||
if config.Server.Path != "" {
|
if config.Server.Path != "" {
|
||||||
paths = append(paths, config.Server.Path)
|
paths = append(paths, config.Server.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Logger().Infof(fmtLogServerInit, "server", connectionType, listener.Addr().String(), strings.Join(paths, "' and '"))
|
return server, listener, paths, isTLS, nil
|
||||||
|
|
||||||
return server, listener, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMetricsServer creates a metrics server.
|
// CreateMetricsServer creates a metrics server.
|
||||||
func CreateMetricsServer(config schema.TelemetryMetricsConfig) (server *fasthttp.Server, listener net.Listener, err error) {
|
func CreateMetricsServer(config *schema.Configuration, providers middlewares.Providers) (server *fasthttp.Server, listener net.Listener, paths []string, tls bool, err error) {
|
||||||
if listener, err = config.Address.Listener(); err != nil {
|
if providers.Metrics == nil {
|
||||||
return nil, nil, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
server = &fasthttp.Server{
|
server = &fasthttp.Server{
|
||||||
ErrorHandler: handleError(),
|
ErrorHandler: handleError(),
|
||||||
NoDefaultServerHeader: true,
|
NoDefaultServerHeader: true,
|
||||||
Handler: handleMetrics(),
|
Handler: handleMetrics(),
|
||||||
ReadBufferSize: config.Buffers.Read,
|
ReadBufferSize: config.Telemetry.Metrics.Buffers.Read,
|
||||||
WriteBufferSize: config.Buffers.Write,
|
WriteBufferSize: config.Telemetry.Metrics.Buffers.Write,
|
||||||
ReadTimeout: config.Timeouts.Read,
|
ReadTimeout: config.Telemetry.Metrics.Timeouts.Read,
|
||||||
WriteTimeout: config.Timeouts.Write,
|
WriteTimeout: config.Telemetry.Metrics.Timeouts.Write,
|
||||||
IdleTimeout: config.Timeouts.Idle,
|
IdleTimeout: config.Telemetry.Metrics.Timeouts.Idle,
|
||||||
Logger: logging.LoggerPrintf(logrus.DebugLevel),
|
Logger: logging.LoggerPrintf(logrus.DebugLevel),
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Logger().Infof(fmtLogServerInit, "server (metrics)", connNonTLS, listener.Addr().String(), "/metrics")
|
if listener, err = config.Telemetry.Metrics.Address.Listener(); err != nil {
|
||||||
|
return nil, nil, nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
return server, listener, nil
|
return server, listener, []string{"/metrics"}, false, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -152,7 +152,7 @@ func NewTLSServerContext(configuration schema.Configuration) (serverContext *TLS
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s, listener, err := CreateDefaultServer(configuration, providers)
|
s, listener, _, _, err := CreateDefaultServer(&configuration, providers)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -42,7 +42,7 @@ func waitUntilAutheliaBackendIsReady(dockerEnvironment *DockerEnvironment) error
|
||||||
90*time.Second,
|
90*time.Second,
|
||||||
dockerEnvironment,
|
dockerEnvironment,
|
||||||
"authelia-backend",
|
"authelia-backend",
|
||||||
[]string{"Initializing server for"})
|
[]string{"Startup Complete"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error {
|
func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error {
|
||||||
|
|
Loading…
Reference in New Issue