authelia/internal/commands/services.go

312 lines
8.1 KiB
Go

package commands
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"syscall"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
"github.com/valyala/fasthttp"
"golang.org/x/sync/errgroup"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/server"
)
// NewServerService assembles a ServerService from an already created fasthttp server and listener.
//
// The name is only used to tag the log entry (field "server"); paths and isTLS are stored so that
// Run can report what is being served and over which kind of connection.
func NewServerService(name string, server *fasthttp.Server, listener net.Listener, paths []string, isTLS bool, log *logrus.Logger) (service *ServerService) {
	// Every message emitted by this service carries the service kind and the server name.
	entry := log.WithFields(map[string]any{"service": "server", "server": name})

	service = &ServerService{
		log:      entry,
		isTLS:    isTLS,
		paths:    paths,
		listener: listener,
		server:   server,
	}

	return service
}
// NewFileWatcherService creates a new FileWatcherService with the appropriate logger etc.
//
// The path must exist. When it is a directory the whole directory is watched; when it is a file the
// parent directory is watched and events are later filtered by the file's base name. The name is only
// used to tag the log entry (field "watcher"), and reload is invoked by Run on write/create events.
//
// An error is returned when the path is empty, cannot be stat'ed, cannot be made absolute, or when the
// underlying fsnotify watcher cannot be created or cannot register the directory.
func NewFileWatcherService(name, path string, reload ProviderReload, log *logrus.Logger) (service *FileWatcherService, err error) {
	if path == "" {
		return nil, fmt.Errorf("path must be specified")
	}

	var info os.FileInfo

	if info, err = os.Stat(path); err != nil {
		return nil, fmt.Errorf("error stating file '%s': %w", path, err)
	}

	// Keep the caller supplied path intact so a failure here reports the path that was actually
	// given rather than the (empty) result of the failed conversion.
	var absolute string

	if absolute, err = filepath.Abs(path); err != nil {
		return nil, fmt.Errorf("error determining absolute path of file '%s': %w", path, err)
	}

	var watcher *fsnotify.Watcher

	if watcher, err = fsnotify.NewWatcher(); err != nil {
		return nil, err
	}

	service = &FileWatcherService{
		watcher: watcher,
		reload:  reload,
		log:     log.WithFields(map[string]any{"service": "watcher", "watcher": name}),
	}

	if info.IsDir() {
		service.directory = filepath.Clean(absolute)
	} else {
		// The directory is watched rather than the file itself; Run filters events by base name.
		service.directory = filepath.Dir(absolute)
		service.file = filepath.Base(absolute)
	}

	if err = watcher.Add(service.directory); err != nil {
		// The watcher is never handed to the caller on this path, so it has to be released here.
		// The close error is intentionally dropped: the Add failure is the error worth reporting.
		_ = watcher.Close()

		return nil, fmt.Errorf("failed to add path '%s' to watch list: %w", absolute, err)
	}

	return service, nil
}
// ProviderReload represents the required methods to support reloading a provider.
// FileWatcherService.Run calls Reload whenever it observes a write or create event on a watched path.
type ProviderReload interface {
// Reload asks the provider to reload itself. FileWatcherService.Run logs a non-nil err as a failed
// reload, reloaded == true as a successful reload, and reloaded == false with a nil err as a skipped reload.
Reload() (reloaded bool, err error)
}
// Service represents the required methods to support handling a service.
// servicesRun starts every Service with Run and stops it with Shutdown.
type Service interface {
// Run executes the service. servicesRun starts it via errgroup.Group.Go, so the returned err is
// what the group's Wait reports once shutdown has finished.
Run() (err error)
// Shutdown stops the service. servicesRun calls it once per started service and waits for it to return.
Shutdown()
}
// ServerService is a Service which runs a webserver.
type ServerService struct {
// server is the fasthttp server that Run serves with and Shutdown stops.
server *fasthttp.Server
// paths is only used by Run to build the "listening" log message.
paths []string
// isTLS selects the "TLS" / "non-TLS" label in the "listening" log message.
isTLS bool
// listener is the listener handed to server.Serve by Run.
listener net.Listener
// log is the entry used for all messages of this service (tagged by NewServerService).
log *logrus.Entry
}
// Run the ServerService: announce the listening address and serve requests on the configured
// listener until the server stops.
//
// An error returned by the server is logged and passed back to the caller. A panic raised while
// serving is recovered and logged; the named result is left untouched in that case.
func (service *ServerService) Run() (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}

		service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
	}()

	kind := connectionType(service.isTLS)
	address := service.listener.Addr().String()
	served := strings.Join(service.paths, "' and '")

	service.log.Infof(fmtLogServerListening, kind, address, served)

	err = service.server.Serve(service.listener)
	if err == nil {
		return nil
	}

	service.log.WithError(err).Error("Error returned attempting to serve requests")

	return err
}
// Shutdown the ServerService by shutting down the wrapped fasthttp server.
// A shutdown error is logged and otherwise ignored, since Shutdown has no result.
func (service *ServerService) Shutdown() {
	err := service.server.Shutdown()
	if err == nil {
		return
	}

	service.log.WithError(err).Error("Error occurred during shutdown")
}
// FileWatcherService is a Service that watches files for changes.
type FileWatcherService struct {
// watcher is the fsnotify watcher whose Events and Errors channels Run consumes; Shutdown closes it.
watcher *fsnotify.Watcher
// reload is invoked by Run when a relevant write or create event is observed.
reload ProviderReload
// log is the entry used for all messages of this service (tagged by NewFileWatcherService).
log *logrus.Entry
// file is the base name events are filtered by; it is empty when a whole directory is watched,
// in which case Run does not filter events by name.
file string
// directory is the path registered with the watcher.
directory string
}
// Run the FileWatcherService: consume watcher events and errors until either channel is closed.
//
// Events for files other than the configured one are traced and skipped (no filtering happens when
// no file name is configured). Write and create events trigger a reload of the provider; remove
// events are only logged. Errors reported by the watcher are logged and do not stop the loop.
// A panic is recovered and logged.
func (service *FileWatcherService) Run() (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}

		service.log.WithError(recoverErr(r)).Error("Critical error caught (recovered)")
	}()

	target := filepath.Join(service.directory, service.file)
	service.log.WithField("file", target).Info("Watching for file changes to the file")

	for {
		select {
		case event, open := <-service.watcher.Events:
			if !open {
				return nil
			}

			// Entry carrying the event details, shared by the messages that report both fields.
			eventLog := service.log.WithFields(map[string]any{"file": event.Name, "op": event.Op})

			if service.file != "" && filepath.Base(event.Name) != service.file {
				eventLog.Tracef("File modification detected to irrelevant file")

				continue
			}

			modified := event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create
			removed := event.Op&fsnotify.Remove == fsnotify.Remove

			if modified {
				eventLog.Debug("File modification was detected")

				var reloaded bool

				// The reload error is stored in the named result, exactly as before.
				reloaded, err = service.reload.Reload()

				if err != nil {
					eventLog.WithError(err).Error("Error occurred during reload")
				} else if reloaded {
					service.log.WithField("file", event.Name).Info("Reloaded successfully")
				} else {
					service.log.WithField("file", event.Name).Debug("Reload of was triggered but it was skipped")
				}
			} else if removed {
				eventLog.Debug("File remove was detected")
			}
		case watchErr, open := <-service.watcher.Errors:
			if !open {
				return nil
			}

			service.log.WithError(watchErr).Errorf("Error while watching files")
		}
	}
}
// Shutdown the FileWatcherService by closing the underlying fsnotify watcher.
// A close error is logged and otherwise ignored, since Shutdown has no result.
func (service *FileWatcherService) Shutdown() {
	err := service.watcher.Close()
	if err == nil {
		return
	}

	service.log.WithError(err).Error("Error occurred during shutdown")
}
// svcSvrMainFunc creates the "main" server service from the default server.
// Both a creation error and a missing server or listener are logged at fatal level; in those
// cases no service is returned.
func svcSvrMainFunc(ctx *CmdCtx) (service Service) {
	svr, listener, paths, isTLS, err := server.CreateDefaultServer(ctx.config, ctx.providers)
	if err != nil {
		ctx.log.WithError(err).Fatal("Create Server Service (main) returned error")

		return nil
	}

	if svr == nil || listener == nil {
		ctx.log.Fatal("Create Server Service (main) failed")

		return nil
	}

	return NewServerService("main", svr, listener, paths, isTLS, ctx.log)
}
// svcSvrMetricsFunc creates the "metrics" server service.
// A creation error is logged at fatal level. When no server or listener was created the service
// is skipped with a debug message and nil is returned.
func svcSvrMetricsFunc(ctx *CmdCtx) (service Service) {
	svr, listener, paths, isTLS, err := server.CreateMetricsServer(ctx.config, ctx.providers)
	if err != nil {
		ctx.log.WithError(err).Fatal("Create Server Service (metrics) returned error")

		return nil
	}

	if svr == nil || listener == nil {
		ctx.log.Debug("Create Server Service (metrics) skipped")

		return nil
	}

	return NewServerService("metrics", svr, listener, paths, isTLS, ctx.log)
}
// svcWatcherUsersFunc creates the "users" file watcher service when the file authentication backend
// is configured with watching enabled; otherwise it returns nil.
//
// Failures (an unexpected user provider type, or an error creating the watcher) are logged at fatal
// level and nil is returned.
func svcWatcherUsersFunc(ctx *CmdCtx) (service Service) {
	file := ctx.config.AuthenticationBackend.File
	if file == nil || !file.Watch {
		return nil
	}

	// Checked assertion: report an unexpected provider type explicitly instead of panicking.
	provider, ok := ctx.providers.UserProvider.(*authentication.FileUserProvider)
	if !ok {
		ctx.log.Fatal("Create Watcher Service (users) failed: user provider is not a file user provider")

		return nil
	}

	// The concrete pointer is kept in its own variable: assigning a nil *FileWatcherService directly
	// to the Service result would produce a non-nil interface that callers would try to run.
	watcher, err := NewFileWatcherService("users", file.Path, provider, ctx.log)
	if err != nil {
		ctx.log.WithError(err).Fatal("Create Watcher Service (users) returned error")

		return nil
	}

	return watcher
}
// connectionType returns the label used in log messages for a listener: "TLS" when isTLS is true
// and "non-TLS" otherwise.
func connectionType(isTLS bool) string {
	label := "non-TLS"

	if isTLS {
		label = "TLS"
	}

	return label
}
// servicesRun starts every configured service, blocks until a shutdown is requested, and then stops
// all services before closing the storage provider.
//
// Shutdown is triggered either by SIGINT/SIGTERM or by completion of the group context (which
// happens when the parent context is done or a service's Run returns an error). All services are
// shut down concurrently and waited for; errors from closing storage and from the service group
// are logged.
func servicesRun(ctx *CmdCtx) {
	cctx, cancel := context.WithCancel(ctx)
	group, cctx := errgroup.WithContext(cctx)

	defer cancel()

	quit := make(chan os.Signal, 1)

	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

	defer signal.Stop(quit)

	var services []Service

	for _, serviceFunc := range []func(ctx *CmdCtx) Service{
		svcSvrMainFunc, svcSvrMetricsFunc,
		svcWatcherUsersFunc,
	} {
		// A nil result means the service is not enabled and is simply skipped.
		if service := serviceFunc(ctx); service != nil {
			services = append(services, service)

			group.Go(service.Run)
		}
	}

	ctx.log.Info("Startup Complete")

	select {
	case s := <-quit:
		switch s {
		case syscall.SIGINT:
			ctx.log.WithField("signal", "SIGINT").Debugf("Shutdown started due to signal")
		case syscall.SIGTERM:
			ctx.log.WithField("signal", "SIGTERM").Debugf("Shutdown started due to signal")
		}
	case <-cctx.Done():
		ctx.log.Debugf("Shutdown started due to context completion")
	}

	cancel()

	ctx.log.Infof("Shutting down")

	var wgShutdown sync.WaitGroup

	for _, service := range services {
		// Add must happen before the goroutine starts; otherwise Done (or Wait) can run first and
		// either panic with a negative counter or return before all services are shut down.
		wgShutdown.Add(1)

		// The service is passed as an argument so each goroutine shuts down its own service even
		// when the loop variable is shared between iterations (Go versions before 1.22).
		go func(service Service) {
			defer wgShutdown.Done()

			service.Shutdown()
		}(service)
	}

	wgShutdown.Wait()

	var err error

	if err = ctx.providers.StorageProvider.Close(); err != nil {
		ctx.log.WithError(err).Error("Error occurred closing database connections")
	}

	if err = group.Wait(); err != nil {
		ctx.log.WithError(err).Errorf("Error occurred waiting for shutdown")
	}
}