refactor: factorize startup checks (#2386)

* refactor: factorize startup checks

* refactor: address linting issues
pull/2387/head
James Elliott 2021-09-17 19:53:59 +10:00 committed by GitHub
parent 8e4dc91b81
commit aed9099ce2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 205 additions and 106 deletions

View File

@ -9,6 +9,7 @@ import (
"sync" "sync"
"github.com/asaskevich/govalidator" "github.com/asaskevich/govalidator"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
@ -205,3 +206,8 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
return err return err
} }
// StartupCheck implements the startup check provider interface.
func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) {
return nil
}

View File

@ -24,6 +24,8 @@ type LDAPUserProvider struct {
logger *logrus.Logger logger *logrus.Logger
connectionFactory LDAPConnectionFactory connectionFactory LDAPConnectionFactory
disableResetPassword bool
// Automatically detected ldap features. // Automatically detected ldap features.
supportExtensionPasswdModify bool supportExtensionPasswdModify bool
@ -41,25 +43,13 @@ type LDAPUserProvider struct {
} }
// NewLDAPUserProvider creates a new instance of LDAPUserProvider. // NewLDAPUserProvider creates a new instance of LDAPUserProvider.
func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider, err error) { func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider) {
provider = newLDAPUserProvider(*configuration.LDAP, certPool, nil) provider = newLDAPUserProvider(*configuration.LDAP, configuration.DisableResetPassword, certPool, nil)
err = provider.checkServer() return provider
if err != nil {
return provider, err
} }
if !provider.supportExtensionPasswdModify && !configuration.DisableResetPassword && func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, disableResetPassword bool, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
provider.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
provider.logger.Warnf("Your LDAP server implementation may not support a method for password hashing " +
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
"attribute when users reset their password via Authelia.")
}
return provider, nil
}
func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
if configuration.TLS == nil { if configuration.TLS == nil {
configuration.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS configuration.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS
} }
@ -84,6 +74,7 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura
dialOpts: dialOpts, dialOpts: dialOpts,
logger: logging.Logger(), logger: logging.Logger(),
connectionFactory: factory, connectionFactory: factory,
disableResetPassword: disableResetPassword,
} }
provider.parseDynamicUsersConfiguration() provider.parseDynamicUsersConfiguration()

View File

@ -4,9 +4,13 @@ import (
"strings" "strings"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
) )
func (p *LDAPUserProvider) checkServer() (err error) { // StartupCheck implements the startup check provider interface.
func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
conn, err := p.connect(p.configuration.User, p.configuration.Password) conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil { if err != nil {
return err return err
@ -29,7 +33,7 @@ func (p *LDAPUserProvider) checkServer() (err error) {
// Iterate the attribute values to see what the server supports. // Iterate the attribute values to see what the server supports.
for _, attr := range sr.Entries[0].Attributes { for _, attr := range sr.Entries[0].Attributes {
if attr.Name == ldapSupportedExtensionAttribute { if attr.Name == ldapSupportedExtensionAttribute {
p.logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", ")) logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
for _, oid := range attr.Values { for _, oid := range attr.Values {
if oid == ldapOIDPasswdModifyExtension { if oid == ldapOIDPasswdModifyExtension {
@ -42,6 +46,13 @@ func (p *LDAPUserProvider) checkServer() (err error) {
} }
} }
if !p.supportExtensionPasswdModify && !p.disableResetPassword &&
p.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
logger.Warn("Your LDAP server implementation may not support a method for password hashing " +
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
"attribute when users reset their password via Authelia.")
}
return nil return nil
} }

View File

@ -12,6 +12,7 @@ import (
"golang.org/x/text/encoding/unicode" "golang.org/x/text/encoding/unicode"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -26,6 +27,7 @@ func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) {
schema.LDAPAuthenticationBackendConfiguration{ schema.LDAPAuthenticationBackendConfiguration{
URL: "ldap://127.0.0.1:389", URL: "ldap://127.0.0.1:389",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -55,6 +57,7 @@ func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) {
schema.LDAPAuthenticationBackendConfiguration{ schema.LDAPAuthenticationBackendConfiguration{
URL: "ldaps://127.0.0.1:389", URL: "ldaps://127.0.0.1:389",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -83,6 +86,7 @@ func TestEscapeSpecialCharsFromUserInput(t *testing.T) {
schema.LDAPAuthenticationBackendConfiguration{ schema.LDAPAuthenticationBackendConfiguration{
URL: "ldaps://127.0.0.1:389", URL: "ldaps://127.0.0.1:389",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -115,6 +119,7 @@ func TestEscapeSpecialCharsInGroupsFilter(t *testing.T) {
URL: "ldaps://127.0.0.1:389", URL: "ldaps://127.0.0.1:389",
GroupsFilter: "(|(member={dn})(uid={username})(uid={input}))", GroupsFilter: "(|(member={dn})(uid={username})(uid={input}))",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -179,6 +184,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -210,7 +216,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose) gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, ldapClient.supportExtensionPasswdModify) assert.True(t, ldapClient.supportExtensionPasswdModify)
@ -235,6 +241,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -266,7 +273,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose) gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, ldapClient.supportExtensionPasswdModify) assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -291,6 +298,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -298,7 +306,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()). DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()).
Return(mockConn, errors.New("could not connect")) Return(mockConn, errors.New("could not connect"))
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
assert.EqualError(t, err, "could not connect") assert.EqualError(t, err, "could not connect")
assert.False(t, ldapClient.supportExtensionPasswdModify) assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -323,6 +331,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -342,7 +351,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose) gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
assert.EqualError(t, err, "could not perform the search") assert.EqualError(t, err, "could not perform the search")
assert.False(t, ldapClient.supportExtensionPasswdModify) assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -384,6 +393,7 @@ func TestShouldEscapeUserInput(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -416,6 +426,7 @@ func TestShouldCombineUsernameFilterAndUsersFilter(t *testing.T) {
MailAttribute: "mail", MailAttribute: "mail",
DisplayNameAttribute: "displayName", DisplayNameAttribute: "displayName",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -463,6 +474,7 @@ func TestShouldNotCrashWhenGroupsAreNotRetrievedFromLDAP(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -532,6 +544,7 @@ func TestShouldNotCrashWhenEmailsAreNotRetrievedFromLDAP(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -594,6 +607,7 @@ func TestShouldReturnUsernameFromLDAP(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -665,6 +679,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -740,7 +755,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
require.NoError(t, err) require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password") err = ldapClient.UpdatePassword("john", "password")
@ -767,6 +782,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -846,7 +862,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
require.NoError(t, err) require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password") err = ldapClient.UpdatePassword("john", "password")
@ -873,6 +889,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -949,7 +966,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.checkServer() err := ldapClient.StartupCheck(logging.Logger())
require.NoError(t, err) require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password") err = ldapClient.UpdatePassword("john", "password")
@ -975,6 +992,7 @@ func TestShouldCheckValidUserPassword(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -1042,6 +1060,7 @@ func TestShouldCheckInvalidUserPassword(t *testing.T) {
AdditionalUsersDN: "ou=users", AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -1110,6 +1129,7 @@ func TestShouldCallStartTLSWhenEnabled(t *testing.T) {
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
StartTLS: true, StartTLS: true,
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -1186,6 +1206,7 @@ func TestShouldParseDynamicConfiguration(t *testing.T) {
BaseDN: "dc=example,dc=com", BaseDN: "dc=example,dc=com",
StartTLS: true, StartTLS: true,
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -1224,6 +1245,7 @@ func TestShouldCallStartTLSWithInsecureSkipVerifyWhenSkipVerifyTrue(t *testing.T
SkipVerify: true, SkipVerify: true,
}, },
}, },
false,
nil, nil,
mockFactory) mockFactory)
@ -1306,6 +1328,7 @@ func TestShouldReturnLDAPSAlreadySecuredWhenStartTLSAttempted(t *testing.T) {
SkipVerify: true, SkipVerify: true,
}, },
}, },
false,
nil, nil,
mockFactory) mockFactory)

View File

@ -1,9 +1,14 @@
package authentication package authentication
import (
"github.com/sirupsen/logrus"
)
// UserProvider is the interface for checking user password and // UserProvider is the interface for checking user password and
// gathering user details. // gathering user details.
type UserProvider interface { type UserProvider interface {
CheckUserPassword(username string, password string) (bool, error) CheckUserPassword(username string, password string) (valid bool, err error)
GetDetails(username string) (*UserDetails, error) GetDetails(username string) (details *UserDetails, err error)
UpdatePassword(username string, newPassword string) error UpdatePassword(username string, newPassword string) (err error)
StartupCheck(logger *logrus.Logger) (err error)
} }

View File

@ -3,7 +3,9 @@ package commands
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authentication"
@ -78,13 +80,13 @@ func cmdRootRun(_ *cobra.Command, _ []string) {
logger.Fatalf("Errors occurred provisioning providers.") logger.Fatalf("Errors occurred provisioning providers.")
} }
doStartupChecks(config, &providers)
server.Start(*config, providers) server.Start(*config, providers)
} }
//nolint:gocyclo // TODO: Consider refactoring time permitting.
func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) { func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) {
logger := logging.Logger() // TODO: Adjust this so the CertPool can be used like a provider.
autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory) autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory)
if len(warnings) != 0 || len(errors) != 0 { if len(warnings) != 0 || len(errors) != 0 {
return providers, warnings, errors return providers, warnings, errors
@ -100,6 +102,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
case config.Storage.Local != nil: case config.Storage.Local != nil:
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path) storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
default: default:
// TODO: Add storage provider startup check and remove this.
errors = append(errors, fmt.Errorf("unrecognized storage provider")) errors = append(errors, fmt.Errorf("unrecognized storage provider"))
} }
@ -112,12 +115,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
case config.AuthenticationBackend.File != nil: case config.AuthenticationBackend.File != nil:
userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File) userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
case config.AuthenticationBackend.LDAP != nil: case config.AuthenticationBackend.LDAP != nil:
userProvider, err = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool) userProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
if err != nil {
errors = append(errors, fmt.Errorf("failed to check LDAP authentication backend: %w", err))
}
default:
errors = append(errors, fmt.Errorf("unrecognized user provider"))
} }
var notifier notification.Notifier var notifier notification.Notifier
@ -127,14 +125,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, autheliaCertPool) notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, autheliaCertPool)
case config.Notifier.FileSystem != nil: case config.Notifier.FileSystem != nil:
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
default:
errors = append(errors, fmt.Errorf("unrecognized notifier provider"))
}
if notifier != nil && !config.Notifier.DisableStartupCheck {
if _, err := notifier.StartupCheck(); err != nil {
errors = append(errors, fmt.Errorf("failed to check notification provider: %w", err))
}
} }
var ntpProvider *ntp.Provider var ntpProvider *ntp.Provider
@ -152,25 +142,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
errors = append(errors, err) errors = append(errors, err)
} }
var failed bool
if !config.NTP.DisableStartupCheck && authorizer.IsSecondFactorEnabled() {
failed, err = ntpProvider.StartupCheck()
if err != nil {
logger.Errorf("Failed to check time against the NTP server: %+v", err)
}
if failed {
if config.NTP.DisableFailure {
logger.Error("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server, this may cause issues in validating TOTP secrets")
} else {
logger.Fatal("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server")
}
} else {
logger.Debug("The system time is within the maximum desynchronization when compared to the time reported by the NTP server")
}
}
return middlewares.Providers{ return middlewares.Providers{
Authorizer: authorizer, Authorizer: authorizer,
UserProvider: userProvider, UserProvider: userProvider,
@ -182,3 +153,53 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
SessionProvider: sessionProvider, SessionProvider: sessionProvider,
}, warnings, errors }, warnings, errors
} }
func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) {
logger := logging.Logger()
var (
failures []string
err error
)
if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil {
logger.Errorf("Failure running the user provider startup check: %+v", err)
failures = append(failures, "user")
}
if err = doStartupCheck(logger, "notification", providers.Notifier, config.Notifier.DisableStartupCheck); err != nil {
logger.Errorf("Failure running the notification provider startup check: %+v", err)
failures = append(failures, "notification")
}
if !config.NTP.DisableStartupCheck && !providers.Authorizer.IsSecondFactorEnabled() {
logger.Debug("The NTP startup check was skipped due to there being no configured 2FA access control rules")
} else if err = doStartupCheck(logger, "ntp", providers.NTP, config.NTP.DisableStartupCheck); err != nil {
logger.Errorf("Failure running the user provider startup check: %+v", err)
failures = append(failures, "ntp")
}
if len(failures) != 0 {
logger.Fatalf("The following providers had fatal failures during startup: %s", strings.Join(failures, ", "))
}
}
func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) {
if disabled {
logger.Debugf("%s provider: startup check skipped as it is disabled", name)
return nil
}
if provider == nil {
return fmt.Errorf("unrecognized provider or it is not configured properly")
}
if err = provider.StartupCheck(logger); err != nil {
return err
}
return nil
}

View File

@ -28,6 +28,11 @@ type AutheliaCtx struct {
Clock utils.Clock Clock utils.Clock
} }
// ProviderWithStartupCheck represents a provider that has a startup check.
type ProviderWithStartupCheck interface {
StartupCheck(logger *logrus.Logger) (err error)
}
// Providers contain all provider provided to Authelia. // Providers contain all provider provided to Authelia.
type Providers struct { type Providers struct {
Authorizer *authorization.Authorizer Authorizer *authorization.Authorizer

View File

@ -8,6 +8,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
) )
// MockNotifier is a mock of Notifier interface. // MockNotifier is a mock of Notifier interface.
@ -48,16 +49,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go
} }
// StartupCheck mocks base method. // StartupCheck mocks base method.
func (m *MockNotifier) StartupCheck() (bool, error) { func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck") ret := m.ctrl.Call(m, "StartupCheck", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// StartupCheck indicates an expected call of StartupCheck. // StartupCheck indicates an expected call of StartupCheck.
func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call { func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0)
} }

View File

@ -8,6 +8,7 @@ import (
"reflect" "reflect"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authentication"
) )
@ -78,3 +79,17 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) *
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1)
} }
// StartupCheck mocks base method.
func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0)
}

View File

@ -7,6 +7,8 @@ import (
"path/filepath" "path/filepath"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
) )
@ -22,28 +24,28 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File
} }
} }
// StartupCheck checks the file provider can write to the specified file. // StartupCheck implements the startup check provider interface.
func (n *FileNotifier) StartupCheck() (bool, error) { func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) {
dir := filepath.Dir(n.path) dir := filepath.Dir(n.path)
if _, err := os.Stat(dir); err != nil { if _, err := os.Stat(dir); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
if err = os.MkdirAll(dir, fileNotifierMode); err != nil { if err = os.MkdirAll(dir, fileNotifierMode); err != nil {
return false, err return err
} }
} else { } else {
return false, err return err
} }
} else if _, err = os.Stat(n.path); err != nil { } else if _, err = os.Stat(n.path); err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return false, err return err
} }
} }
if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil { if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil {
return false, err return err
} }
return true, nil return nil
} }
// Send send a identity verification link to a user. // Send send a identity verification link to a user.

View File

@ -1,7 +1,11 @@
package notification package notification
import (
"github.com/sirupsen/logrus"
)
// Notifier interface for sending the identity verification link. // Notifier interface for sending the identity verification link.
type Notifier interface { type Notifier interface {
Send(recipient, subject, body, htmlBody string) error Send(recipient, subject, body, htmlBody string) (err error)
StartupCheck() (bool, error) StartupCheck(logger *logrus.Logger) (err error)
} }

View File

@ -10,6 +10,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
@ -220,39 +222,39 @@ func (n *SMTPNotifier) cleanup() {
} }
} }
// StartupCheck checks the server is functioning correctly and the configuration is correct. // StartupCheck implements the startup check provider interface.
func (n *SMTPNotifier) StartupCheck() (bool, error) { func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) {
if err := n.dial(); err != nil { if err := n.dial(); err != nil {
return false, err return err
} }
defer n.cleanup() defer n.cleanup()
if err := n.client.Hello(n.configuration.Identifier); err != nil { if err := n.client.Hello(n.configuration.Identifier); err != nil {
return false, err return err
} }
if err := n.startTLS(); err != nil { if err := n.startTLS(); err != nil {
return false, err return err
} }
if err := n.auth(); err != nil { if err := n.auth(); err != nil {
return false, err return err
} }
if err := n.client.Mail(n.configuration.Sender); err != nil { if err := n.client.Mail(n.configuration.Sender); err != nil {
return false, err return err
} }
if err := n.client.Rcpt(n.configuration.StartupCheckAddress); err != nil { if err := n.client.Rcpt(n.configuration.StartupCheckAddress); err != nil {
return false, err return err
} }
if err := n.client.Reset(); err != nil { if err := n.client.Reset(); err != nil {
return false, err return err
} }
return true, nil return nil
} }
// Send is used to send an email to a recipient. // Send is used to send an email to a recipient.

View File

@ -2,10 +2,12 @@ package ntp
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "errors"
"net" "net"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -15,17 +17,21 @@ func NewProvider(config *schema.NTPConfiguration) *Provider {
return &Provider{config} return &Provider{config}
} }
// StartupCheck checks if the system clock is not out of sync. // StartupCheck implements the startup check provider interface.
func (p *Provider) StartupCheck() (failed bool, err error) { func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
conn, err := net.Dial("udp", p.config.Address) conn, err := net.Dial("udp", p.config.Address)
if err != nil { if err != nil {
return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err) logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
return nil
} }
defer conn.Close() defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err) logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
return nil
} }
version := ntpV4 version := ntpV4
@ -36,7 +42,9 @@ func (p *Provider) StartupCheck() (failed bool, err error) {
req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)} req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)}
if err := binary.Write(conn, binary.BigEndian, req); err != nil { if err := binary.Write(conn, binary.BigEndian, req); err != nil {
return false, fmt.Errorf("could not write to the NTP server socket to validate the time desync: %w", err) logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
return nil
} }
now := time.Now() now := time.Now()
@ -44,12 +52,18 @@ func (p *Provider) StartupCheck() (failed bool, err error) {
resp := &ntpPacket{} resp := &ntpPacket{}
if err := binary.Read(conn, binary.BigEndian, resp); err != nil { if err := binary.Read(conn, binary.BigEndian, resp); err != nil {
return false, fmt.Errorf("could not read from the NTP server socket to validate the time desync: %w", err) logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
return nil
} }
maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync) maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync)
ntpTime := ntpPacketToTime(resp) ntpTime := ntpPacketToTime(resp)
return ntpIsOffsetTooLarge(maxOffset, now, ntpTime), nil if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result {
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
}
return nil
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator" "github.com/authelia/authelia/v4/internal/configuration/validator"
"github.com/authelia/authelia/v4/internal/logging"
) )
func TestShouldCheckNTP(t *testing.T) { func TestShouldCheckNTP(t *testing.T) {
@ -19,8 +20,7 @@ func TestShouldCheckNTP(t *testing.T) {
sv := schema.NewStructValidator() sv := schema.NewStructValidator()
validator.ValidateNTP(&config, sv) validator.ValidateNTP(&config, sv)
NTP := NewProvider(&config) ntp := NewProvider(&config)
checkfailed, _ := NTP.StartupCheck() assert.NoError(t, ntp.StartupCheck(logging.Logger()))
assert.Equal(t, false, checkfailed)
} }