From aed9099ce2fa704c19ac351da32f31e87c54a1b5 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 17 Sep 2021 19:53:59 +1000 Subject: [PATCH] refactor: factorize startup checks (#2386) * refactor: factorize startup checks * refactor: address linting issues --- internal/authentication/file_user_provider.go | 6 ++ internal/authentication/ldap_user_provider.go | 33 +++---- .../ldap_user_provider_startup.go | 15 ++- .../authentication/ldap_user_provider_test.go | 37 ++++++-- internal/authentication/user_provider.go | 11 ++- internal/commands/root.go | 93 ++++++++++++------- internal/middlewares/types.go | 5 + internal/mocks/mock_notifier.go | 14 +-- internal/mocks/mock_user_provider.go | 15 +++ internal/notification/file_notifier.go | 16 ++-- internal/notification/notifier.go | 8 +- internal/notification/smtp_notifier.go | 22 +++-- internal/ntp/ntp.go | 30 ++++-- internal/ntp/ntp_test.go | 6 +- 14 files changed, 205 insertions(+), 106 deletions(-) diff --git a/internal/authentication/file_user_provider.go b/internal/authentication/file_user_provider.go index 74cf4ff86..9947c4f75 100644 --- a/internal/authentication/file_user_provider.go +++ b/internal/authentication/file_user_provider.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/asaskevich/govalidator" + "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" "github.com/authelia/authelia/v4/internal/configuration/schema" @@ -205,3 +206,8 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e return err } + +// StartupCheck implements the startup check provider interface. +func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) { + return nil +} diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go index 00e8b8ff1..5e7bd645c 100644 --- a/internal/authentication/ldap_user_provider.go +++ b/internal/authentication/ldap_user_provider.go @@ -24,6 +24,8 @@ type LDAPUserProvider struct { logger *logrus.Logger connectionFactory LDAPConnectionFactory + disableResetPassword bool + // Automatically detected ldap features. supportExtensionPasswdModify bool @@ -41,25 +43,13 @@ type LDAPUserProvider struct { } // NewLDAPUserProvider creates a new instance of LDAPUserProvider. -func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider, err error) { - provider = newLDAPUserProvider(*configuration.LDAP, certPool, nil) +func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider) { + provider = newLDAPUserProvider(*configuration.LDAP, configuration.DisableResetPassword, certPool, nil) - err = provider.checkServer() - if err != nil { - return provider, err - } - - if !provider.supportExtensionPasswdModify && !configuration.DisableResetPassword && - provider.configuration.Implementation != schema.LDAPImplementationActiveDirectory { - provider.logger.Warnf("Your LDAP server implementation may not support a method for password hashing " + - "known to Authelia, it's strongly recommended you ensure your directory server hashes the password " + - "attribute when users reset their password via Authelia.") - } - - return provider, nil + return provider } -func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) { +func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, disableResetPassword bool, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) { if configuration.TLS == nil { configuration.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS } @@ -79,11 +69,12 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura } provider = &LDAPUserProvider{ - configuration: configuration, - tlsConfig: tlsConfig, - dialOpts: dialOpts, - logger: logging.Logger(), - connectionFactory: factory, + configuration: configuration, + tlsConfig: tlsConfig, + dialOpts: dialOpts, + logger: logging.Logger(), + connectionFactory: factory, + disableResetPassword: disableResetPassword, } provider.parseDynamicUsersConfiguration() diff --git a/internal/authentication/ldap_user_provider_startup.go b/internal/authentication/ldap_user_provider_startup.go index f06985335..5c56df22c 100644 --- a/internal/authentication/ldap_user_provider_startup.go +++ b/internal/authentication/ldap_user_provider_startup.go @@ -4,9 +4,13 @@ import ( "strings" "github.com/go-ldap/ldap/v3" + "github.com/sirupsen/logrus" + + "github.com/authelia/authelia/v4/internal/configuration/schema" ) -func (p *LDAPUserProvider) checkServer() (err error) { +// StartupCheck implements the startup check provider interface. +func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) { conn, err := p.connect(p.configuration.User, p.configuration.Password) if err != nil { return err @@ -29,7 +33,7 @@ func (p *LDAPUserProvider) checkServer() (err error) { // Iterate the attribute values to see what the server supports. for _, attr := range sr.Entries[0].Attributes { if attr.Name == ldapSupportedExtensionAttribute { - p.logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", ")) + logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", ")) for _, oid := range attr.Values { if oid == ldapOIDPasswdModifyExtension { @@ -42,6 +46,13 @@ func (p *LDAPUserProvider) checkServer() (err error) { } } + if !p.supportExtensionPasswdModify && !p.disableResetPassword && + p.configuration.Implementation != schema.LDAPImplementationActiveDirectory { + logger.Warn("Your LDAP server implementation may not support a method for password hashing " + + "known to Authelia, it's strongly recommended you ensure your directory server hashes the password " + + "attribute when users reset their password via Authelia.") + } + return nil } diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go index 701c8564b..b9b739f38 100644 --- a/internal/authentication/ldap_user_provider_test.go +++ b/internal/authentication/ldap_user_provider_test.go @@ -12,6 +12,7 @@ import ( "golang.org/x/text/encoding/unicode" "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/utils" ) @@ -26,6 +27,7 @@ func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) { schema.LDAPAuthenticationBackendConfiguration{ URL: "ldap://127.0.0.1:389", }, + false, nil, mockFactory) @@ -55,6 +57,7 @@ func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) { schema.LDAPAuthenticationBackendConfiguration{ URL: "ldaps://127.0.0.1:389", }, + false, nil, mockFactory) @@ -83,6 +86,7 @@ func TestEscapeSpecialCharsFromUserInput(t *testing.T) { schema.LDAPAuthenticationBackendConfiguration{ URL: "ldaps://127.0.0.1:389", }, + false, nil, mockFactory) @@ -115,6 +119,7 @@ func TestEscapeSpecialCharsInGroupsFilter(t *testing.T) { URL: "ldaps://127.0.0.1:389", GroupsFilter: "(|(member={dn})(uid={username})(uid={input}))", }, + false, nil, mockFactory) @@ -179,6 +184,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -210,7 +216,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) { gomock.InOrder(dialURL, connBind, searchOIDs, connClose) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) assert.NoError(t, err) assert.True(t, ldapClient.supportExtensionPasswdModify) @@ -235,6 +241,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -266,7 +273,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) { gomock.InOrder(dialURL, connBind, searchOIDs, connClose) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) assert.NoError(t, err) assert.False(t, ldapClient.supportExtensionPasswdModify) @@ -291,6 +298,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -298,7 +306,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) { DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()). Return(mockConn, errors.New("could not connect")) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) assert.EqualError(t, err, "could not connect") assert.False(t, ldapClient.supportExtensionPasswdModify) @@ -323,6 +331,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -342,7 +351,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) { gomock.InOrder(dialURL, connBind, searchOIDs, connClose) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) assert.EqualError(t, err, "could not perform the search") assert.False(t, ldapClient.supportExtensionPasswdModify) @@ -384,6 +393,7 @@ func TestShouldEscapeUserInput(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -416,6 +426,7 @@ func TestShouldCombineUsernameFilterAndUsersFilter(t *testing.T) { MailAttribute: "mail", DisplayNameAttribute: "displayName", }, + false, nil, mockFactory) @@ -463,6 +474,7 @@ func TestShouldNotCrashWhenGroupsAreNotRetrievedFromLDAP(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -532,6 +544,7 @@ func TestShouldNotCrashWhenEmailsAreNotRetrievedFromLDAP(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -594,6 +607,7 @@ func TestShouldReturnUsernameFromLDAP(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -665,6 +679,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -740,7 +755,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) { gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) require.NoError(t, err) err = ldapClient.UpdatePassword("john", "password") @@ -767,6 +782,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -846,7 +862,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) { gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) require.NoError(t, err) err = ldapClient.UpdatePassword("john", "password") @@ -873,6 +889,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -949,7 +966,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) { gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) - err := ldapClient.checkServer() + err := ldapClient.StartupCheck(logging.Logger()) require.NoError(t, err) err = ldapClient.UpdatePassword("john", "password") @@ -975,6 +992,7 @@ func TestShouldCheckValidUserPassword(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -1042,6 +1060,7 @@ func TestShouldCheckInvalidUserPassword(t *testing.T) { AdditionalUsersDN: "ou=users", BaseDN: "dc=example,dc=com", }, + false, nil, mockFactory) @@ -1110,6 +1129,7 @@ func TestShouldCallStartTLSWhenEnabled(t *testing.T) { BaseDN: "dc=example,dc=com", StartTLS: true, }, + false, nil, mockFactory) @@ -1186,6 +1206,7 @@ func TestShouldParseDynamicConfiguration(t *testing.T) { BaseDN: "dc=example,dc=com", StartTLS: true, }, + false, nil, mockFactory) @@ -1224,6 +1245,7 @@ func TestShouldCallStartTLSWithInsecureSkipVerifyWhenSkipVerifyTrue(t *testing.T SkipVerify: true, }, }, + false, nil, mockFactory) @@ -1306,6 +1328,7 @@ func TestShouldReturnLDAPSAlreadySecuredWhenStartTLSAttempted(t *testing.T) { SkipVerify: true, }, }, + false, nil, mockFactory) diff --git a/internal/authentication/user_provider.go b/internal/authentication/user_provider.go index 9f721ce40..d70675cef 100644 --- a/internal/authentication/user_provider.go +++ b/internal/authentication/user_provider.go @@ -1,9 +1,14 @@ package authentication +import ( + "github.com/sirupsen/logrus" +) + // UserProvider is the interface for checking user password and // gathering user details. type UserProvider interface { - CheckUserPassword(username string, password string) (bool, error) - GetDetails(username string) (*UserDetails, error) - UpdatePassword(username string, newPassword string) error + CheckUserPassword(username string, password string) (valid bool, err error) + GetDetails(username string) (details *UserDetails, err error) + UpdatePassword(username string, newPassword string) (err error) + StartupCheck(logger *logrus.Logger) (err error) } diff --git a/internal/commands/root.go b/internal/commands/root.go index c36176d29..7d8f4be84 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -3,7 +3,9 @@ package commands import ( "fmt" "os" + "strings" + "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/authelia/authelia/v4/internal/authentication" @@ -78,13 +80,13 @@ func cmdRootRun(_ *cobra.Command, _ []string) { logger.Fatalf("Errors occurred provisioning providers.") } + doStartupChecks(config, &providers) + server.Start(*config, providers) } -//nolint:gocyclo // TODO: Consider refactoring time permitting. func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) { - logger := logging.Logger() - + // TODO: Adjust this so the CertPool can be used like a provider. autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory) if len(warnings) != 0 || len(errors) != 0 { return providers, warnings, errors @@ -100,6 +102,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers case config.Storage.Local != nil: storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path) default: + // TODO: Add storage provider startup check and remove this. errors = append(errors, fmt.Errorf("unrecognized storage provider")) } @@ -112,12 +115,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers case config.AuthenticationBackend.File != nil: userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File) case config.AuthenticationBackend.LDAP != nil: - userProvider, err = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool) - if err != nil { - errors = append(errors, fmt.Errorf("failed to check LDAP authentication backend: %w", err)) - } - default: - errors = append(errors, fmt.Errorf("unrecognized user provider")) + userProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool) } var notifier notification.Notifier @@ -127,14 +125,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, autheliaCertPool) case config.Notifier.FileSystem != nil: notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) - default: - errors = append(errors, fmt.Errorf("unrecognized notifier provider")) - } - - if notifier != nil && !config.Notifier.DisableStartupCheck { - if _, err := notifier.StartupCheck(); err != nil { - errors = append(errors, fmt.Errorf("failed to check notification provider: %w", err)) - } } var ntpProvider *ntp.Provider @@ -152,25 +142,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers errors = append(errors, err) } - var failed bool - if !config.NTP.DisableStartupCheck && authorizer.IsSecondFactorEnabled() { - failed, err = ntpProvider.StartupCheck() - - if err != nil { - logger.Errorf("Failed to check time against the NTP server: %+v", err) - } - - if failed { - if config.NTP.DisableFailure { - logger.Error("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server, this may cause issues in validating TOTP secrets") - } else { - logger.Fatal("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server") - } - } else { - logger.Debug("The system time is within the maximum desynchronization when compared to the time reported by the NTP server") - } - } - return middlewares.Providers{ Authorizer: authorizer, UserProvider: userProvider, @@ -182,3 +153,53 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers SessionProvider: sessionProvider, }, warnings, errors } + +func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) { + logger := logging.Logger() + + var ( + failures []string + err error + ) + + if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil { + logger.Errorf("Failure running the user provider startup check: %+v", err) + + failures = append(failures, "user") + } + + if err = doStartupCheck(logger, "notification", providers.Notifier, config.Notifier.DisableStartupCheck); err != nil { + logger.Errorf("Failure running the notification provider startup check: %+v", err) + + failures = append(failures, "notification") + } + + if !config.NTP.DisableStartupCheck && !providers.Authorizer.IsSecondFactorEnabled() { + logger.Debug("The NTP startup check was skipped due to there being no configured 2FA access control rules") + } else if err = doStartupCheck(logger, "ntp", providers.NTP, config.NTP.DisableStartupCheck); err != nil { + logger.Errorf("Failure running the user provider startup check: %+v", err) + + failures = append(failures, "ntp") + } + + if len(failures) != 0 { + logger.Fatalf("The following providers had fatal failures during startup: %s", strings.Join(failures, ", ")) + } +} + +func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) { + if disabled { + logger.Debugf("%s provider: startup check skipped as it is disabled", name) + return nil + } + + if provider == nil { + return fmt.Errorf("unrecognized provider or it is not configured properly") + } + + if err = provider.StartupCheck(logger); err != nil { + return err + } + + return nil +} diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index a04209193..a02d2cd70 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -28,6 +28,11 @@ type AutheliaCtx struct { Clock utils.Clock } +// ProviderWithStartupCheck represents a provider that has a startup check. +type ProviderWithStartupCheck interface { + StartupCheck(logger *logrus.Logger) (err error) +} + // Providers contain all provider provided to Authelia. type Providers struct { Authorizer *authorization.Authorizer diff --git a/internal/mocks/mock_notifier.go b/internal/mocks/mock_notifier.go index 77ecbf00a..aebba9a08 100644 --- a/internal/mocks/mock_notifier.go +++ b/internal/mocks/mock_notifier.go @@ -8,6 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + "github.com/sirupsen/logrus" ) // MockNotifier is a mock of Notifier interface. @@ -48,16 +49,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go } // StartupCheck mocks base method. -func (m *MockNotifier) StartupCheck() (bool, error) { +func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartupCheck") - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "StartupCheck", arg0) + ret0, _ := ret[0].(error) + return ret0 } // StartupCheck indicates an expected call of StartupCheck. -func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call { +func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0) } diff --git a/internal/mocks/mock_user_provider.go b/internal/mocks/mock_user_provider.go index 5d65d4fd3..c05d223e8 100644 --- a/internal/mocks/mock_user_provider.go +++ b/internal/mocks/mock_user_provider.go @@ -8,6 +8,7 @@ import ( "reflect" "github.com/golang/mock/gomock" + "github.com/sirupsen/logrus" "github.com/authelia/authelia/v4/internal/authentication" ) @@ -78,3 +79,17 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) * mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1) } + +// StartupCheck mocks base method. +func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartupCheck", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// StartupCheck indicates an expected call of StartupCheck. +func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0) +} diff --git a/internal/notification/file_notifier.go b/internal/notification/file_notifier.go index 4fad8a5d6..f0de2174e 100644 --- a/internal/notification/file_notifier.go +++ b/internal/notification/file_notifier.go @@ -7,6 +7,8 @@ import ( "path/filepath" "time" + "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/configuration/schema" ) @@ -22,28 +24,28 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File } } -// StartupCheck checks the file provider can write to the specified file. -func (n *FileNotifier) StartupCheck() (bool, error) { +// StartupCheck implements the startup check provider interface. +func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) { dir := filepath.Dir(n.path) if _, err := os.Stat(dir); err != nil { if os.IsNotExist(err) { if err = os.MkdirAll(dir, fileNotifierMode); err != nil { - return false, err + return err } } else { - return false, err + return err } } else if _, err = os.Stat(n.path); err != nil { if !os.IsNotExist(err) { - return false, err + return err } } if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil { - return false, err + return err } - return true, nil + return nil } // Send send a identity verification link to a user. diff --git a/internal/notification/notifier.go b/internal/notification/notifier.go index 1927fc30a..74210d494 100644 --- a/internal/notification/notifier.go +++ b/internal/notification/notifier.go @@ -1,7 +1,11 @@ package notification +import ( + "github.com/sirupsen/logrus" +) + // Notifier interface for sending the identity verification link. type Notifier interface { - Send(recipient, subject, body, htmlBody string) error - StartupCheck() (bool, error) + Send(recipient, subject, body, htmlBody string) (err error) + StartupCheck(logger *logrus.Logger) (err error) } diff --git a/internal/notification/smtp_notifier.go b/internal/notification/smtp_notifier.go index f223a3c1d..11c2cea1f 100644 --- a/internal/notification/smtp_notifier.go +++ b/internal/notification/smtp_notifier.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/utils" @@ -220,39 +222,39 @@ func (n *SMTPNotifier) cleanup() { } } -// StartupCheck checks the server is functioning correctly and the configuration is correct. -func (n *SMTPNotifier) StartupCheck() (bool, error) { +// StartupCheck implements the startup check provider interface. +func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) { if err := n.dial(); err != nil { - return false, err + return err } defer n.cleanup() if err := n.client.Hello(n.configuration.Identifier); err != nil { - return false, err + return err } if err := n.startTLS(); err != nil { - return false, err + return err } if err := n.auth(); err != nil { - return false, err + return err } if err := n.client.Mail(n.configuration.Sender); err != nil { - return false, err + return err } if err := n.client.Rcpt(n.configuration.StartupCheckAddress); err != nil { - return false, err + return err } if err := n.client.Reset(); err != nil { - return false, err + return err } - return true, nil + return nil } // Send is used to send an email to a recipient. diff --git a/internal/ntp/ntp.go b/internal/ntp/ntp.go index 2ca914d65..e30c4ca6f 100644 --- a/internal/ntp/ntp.go +++ b/internal/ntp/ntp.go @@ -2,10 +2,12 @@ package ntp import ( "encoding/binary" - "fmt" + "errors" "net" "time" + "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/utils" ) @@ -15,17 +17,21 @@ func NewProvider(config *schema.NTPConfiguration) *Provider { return &Provider{config} } -// StartupCheck checks if the system clock is not out of sync. -func (p *Provider) StartupCheck() (failed bool, err error) { +// StartupCheck implements the startup check provider interface. +func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) { conn, err := net.Dial("udp", p.config.Address) if err != nil { - return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err) + logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err) + + return nil } defer conn.Close() if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err) + logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err) + + return nil } version := ntpV4 @@ -36,7 +42,9 @@ func (p *Provider) StartupCheck() (failed bool, err error) { req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)} if err := binary.Write(conn, binary.BigEndian, req); err != nil { - return false, fmt.Errorf("could not write to the NTP server socket to validate the time desync: %w", err) + logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err) + + return nil } now := time.Now() @@ -44,12 +52,18 @@ func (p *Provider) StartupCheck() (failed bool, err error) { resp := &ntpPacket{} if err := binary.Read(conn, binary.BigEndian, resp); err != nil { - return false, fmt.Errorf("could not read from the NTP server socket to validate the time desync: %w", err) + logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err) + + return nil } maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync) ntpTime := ntpPacketToTime(resp) - return ntpIsOffsetTooLarge(maxOffset, now, ntpTime), nil + if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result { + return errors.New("the system clock is not synchronized accurately enough with the configured NTP server") + } + + return nil } diff --git a/internal/ntp/ntp_test.go b/internal/ntp/ntp_test.go index 3c578b962..95e230b8e 100644 --- a/internal/ntp/ntp_test.go +++ b/internal/ntp/ntp_test.go @@ -7,6 +7,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/validator" + "github.com/authelia/authelia/v4/internal/logging" ) func TestShouldCheckNTP(t *testing.T) { @@ -19,8 +20,7 @@ func TestShouldCheckNTP(t *testing.T) { sv := schema.NewStructValidator() validator.ValidateNTP(&config, sv) - NTP := NewProvider(&config) + ntp := NewProvider(&config) - checkfailed, _ := NTP.StartupCheck() - assert.Equal(t, false, checkfailed) + assert.NoError(t, ntp.StartupCheck(logging.Logger())) }