refactor: factorize startup checks (#2386)
* refactor: factorize startup checks * refactor: address linting issuespull/2387/head
parent
8e4dc91b81
commit
aed9099ce2
|
@ -9,6 +9,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/asaskevich/govalidator"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
|
@ -205,3 +206,8 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
|
|||
|
||||
return err
|
||||
}
|
||||
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -24,6 +24,8 @@ type LDAPUserProvider struct {
|
|||
logger *logrus.Logger
|
||||
connectionFactory LDAPConnectionFactory
|
||||
|
||||
disableResetPassword bool
|
||||
|
||||
// Automatically detected ldap features.
|
||||
supportExtensionPasswdModify bool
|
||||
|
||||
|
@ -41,25 +43,13 @@ type LDAPUserProvider struct {
|
|||
}
|
||||
|
||||
// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
|
||||
func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider, err error) {
|
||||
provider = newLDAPUserProvider(*configuration.LDAP, certPool, nil)
|
||||
func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider) {
|
||||
provider = newLDAPUserProvider(*configuration.LDAP, configuration.DisableResetPassword, certPool, nil)
|
||||
|
||||
err = provider.checkServer()
|
||||
if err != nil {
|
||||
return provider, err
|
||||
}
|
||||
|
||||
if !provider.supportExtensionPasswdModify && !configuration.DisableResetPassword &&
|
||||
provider.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
|
||||
provider.logger.Warnf("Your LDAP server implementation may not support a method for password hashing " +
|
||||
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
|
||||
"attribute when users reset their password via Authelia.")
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
return provider
|
||||
}
|
||||
|
||||
func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
|
||||
func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, disableResetPassword bool, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
|
||||
if configuration.TLS == nil {
|
||||
configuration.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS
|
||||
}
|
||||
|
@ -84,6 +74,7 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura
|
|||
dialOpts: dialOpts,
|
||||
logger: logging.Logger(),
|
||||
connectionFactory: factory,
|
||||
disableResetPassword: disableResetPassword,
|
||||
}
|
||||
|
||||
provider.parseDynamicUsersConfiguration()
|
||||
|
|
|
@ -4,9 +4,13 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
func (p *LDAPUserProvider) checkServer() (err error) {
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
|
||||
conn, err := p.connect(p.configuration.User, p.configuration.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -29,7 +33,7 @@ func (p *LDAPUserProvider) checkServer() (err error) {
|
|||
// Iterate the attribute values to see what the server supports.
|
||||
for _, attr := range sr.Entries[0].Attributes {
|
||||
if attr.Name == ldapSupportedExtensionAttribute {
|
||||
p.logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
|
||||
logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
|
||||
|
||||
for _, oid := range attr.Values {
|
||||
if oid == ldapOIDPasswdModifyExtension {
|
||||
|
@ -42,6 +46,13 @@ func (p *LDAPUserProvider) checkServer() (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
if !p.supportExtensionPasswdModify && !p.disableResetPassword &&
|
||||
p.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
|
||||
logger.Warn("Your LDAP server implementation may not support a method for password hashing " +
|
||||
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
|
||||
"attribute when users reset their password via Authelia.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -26,6 +27,7 @@ func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) {
|
|||
schema.LDAPAuthenticationBackendConfiguration{
|
||||
URL: "ldap://127.0.0.1:389",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -55,6 +57,7 @@ func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) {
|
|||
schema.LDAPAuthenticationBackendConfiguration{
|
||||
URL: "ldaps://127.0.0.1:389",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -83,6 +86,7 @@ func TestEscapeSpecialCharsFromUserInput(t *testing.T) {
|
|||
schema.LDAPAuthenticationBackendConfiguration{
|
||||
URL: "ldaps://127.0.0.1:389",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -115,6 +119,7 @@ func TestEscapeSpecialCharsInGroupsFilter(t *testing.T) {
|
|||
URL: "ldaps://127.0.0.1:389",
|
||||
GroupsFilter: "(|(member={dn})(uid={username})(uid={input}))",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -179,6 +184,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -210,7 +216,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -235,6 +241,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -266,7 +273,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -291,6 +298,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -298,7 +306,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
|
|||
DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()).
|
||||
Return(mockConn, errors.New("could not connect"))
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
assert.EqualError(t, err, "could not connect")
|
||||
|
||||
assert.False(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -323,6 +331,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -342,7 +351,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
assert.EqualError(t, err, "could not perform the search")
|
||||
|
||||
assert.False(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -384,6 +393,7 @@ func TestShouldEscapeUserInput(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -416,6 +426,7 @@ func TestShouldCombineUsernameFilterAndUsersFilter(t *testing.T) {
|
|||
MailAttribute: "mail",
|
||||
DisplayNameAttribute: "displayName",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -463,6 +474,7 @@ func TestShouldNotCrashWhenGroupsAreNotRetrievedFromLDAP(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -532,6 +544,7 @@ func TestShouldNotCrashWhenEmailsAreNotRetrievedFromLDAP(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -594,6 +607,7 @@ func TestShouldReturnUsernameFromLDAP(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -665,6 +679,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -740,7 +755,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ldapClient.UpdatePassword("john", "password")
|
||||
|
@ -767,6 +782,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -846,7 +862,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ldapClient.UpdatePassword("john", "password")
|
||||
|
@ -873,6 +889,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -949,7 +966,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
|
||||
|
||||
err := ldapClient.checkServer()
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ldapClient.UpdatePassword("john", "password")
|
||||
|
@ -975,6 +992,7 @@ func TestShouldCheckValidUserPassword(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -1042,6 +1060,7 @@ func TestShouldCheckInvalidUserPassword(t *testing.T) {
|
|||
AdditionalUsersDN: "ou=users",
|
||||
BaseDN: "dc=example,dc=com",
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -1110,6 +1129,7 @@ func TestShouldCallStartTLSWhenEnabled(t *testing.T) {
|
|||
BaseDN: "dc=example,dc=com",
|
||||
StartTLS: true,
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -1186,6 +1206,7 @@ func TestShouldParseDynamicConfiguration(t *testing.T) {
|
|||
BaseDN: "dc=example,dc=com",
|
||||
StartTLS: true,
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -1224,6 +1245,7 @@ func TestShouldCallStartTLSWithInsecureSkipVerifyWhenSkipVerifyTrue(t *testing.T
|
|||
SkipVerify: true,
|
||||
},
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
@ -1306,6 +1328,7 @@ func TestShouldReturnLDAPSAlreadySecuredWhenStartTLSAttempted(t *testing.T) {
|
|||
SkipVerify: true,
|
||||
},
|
||||
},
|
||||
false,
|
||||
nil,
|
||||
mockFactory)
|
||||
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
package authentication
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// UserProvider is the interface for checking user password and
|
||||
// gathering user details.
|
||||
type UserProvider interface {
|
||||
CheckUserPassword(username string, password string) (bool, error)
|
||||
GetDetails(username string) (*UserDetails, error)
|
||||
UpdatePassword(username string, newPassword string) error
|
||||
CheckUserPassword(username string, password string) (valid bool, err error)
|
||||
GetDetails(username string) (details *UserDetails, err error)
|
||||
UpdatePassword(username string, newPassword string) (err error)
|
||||
StartupCheck(logger *logrus.Logger) (err error)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,9 @@ package commands
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
|
@ -78,13 +80,13 @@ func cmdRootRun(_ *cobra.Command, _ []string) {
|
|||
logger.Fatalf("Errors occurred provisioning providers.")
|
||||
}
|
||||
|
||||
doStartupChecks(config, &providers)
|
||||
|
||||
server.Start(*config, providers)
|
||||
}
|
||||
|
||||
//nolint:gocyclo // TODO: Consider refactoring time permitting.
|
||||
func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) {
|
||||
logger := logging.Logger()
|
||||
|
||||
// TODO: Adjust this so the CertPool can be used like a provider.
|
||||
autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory)
|
||||
if len(warnings) != 0 || len(errors) != 0 {
|
||||
return providers, warnings, errors
|
||||
|
@ -100,6 +102,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
|
|||
case config.Storage.Local != nil:
|
||||
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
|
||||
default:
|
||||
// TODO: Add storage provider startup check and remove this.
|
||||
errors = append(errors, fmt.Errorf("unrecognized storage provider"))
|
||||
}
|
||||
|
||||
|
@ -112,12 +115,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
|
|||
case config.AuthenticationBackend.File != nil:
|
||||
userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
|
||||
case config.AuthenticationBackend.LDAP != nil:
|
||||
userProvider, err = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("failed to check LDAP authentication backend: %w", err))
|
||||
}
|
||||
default:
|
||||
errors = append(errors, fmt.Errorf("unrecognized user provider"))
|
||||
userProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
|
||||
}
|
||||
|
||||
var notifier notification.Notifier
|
||||
|
@ -127,14 +125,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
|
|||
notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, autheliaCertPool)
|
||||
case config.Notifier.FileSystem != nil:
|
||||
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
|
||||
default:
|
||||
errors = append(errors, fmt.Errorf("unrecognized notifier provider"))
|
||||
}
|
||||
|
||||
if notifier != nil && !config.Notifier.DisableStartupCheck {
|
||||
if _, err := notifier.StartupCheck(); err != nil {
|
||||
errors = append(errors, fmt.Errorf("failed to check notification provider: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
var ntpProvider *ntp.Provider
|
||||
|
@ -152,25 +142,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
|
|||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
var failed bool
|
||||
if !config.NTP.DisableStartupCheck && authorizer.IsSecondFactorEnabled() {
|
||||
failed, err = ntpProvider.StartupCheck()
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to check time against the NTP server: %+v", err)
|
||||
}
|
||||
|
||||
if failed {
|
||||
if config.NTP.DisableFailure {
|
||||
logger.Error("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server, this may cause issues in validating TOTP secrets")
|
||||
} else {
|
||||
logger.Fatal("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server")
|
||||
}
|
||||
} else {
|
||||
logger.Debug("The system time is within the maximum desynchronization when compared to the time reported by the NTP server")
|
||||
}
|
||||
}
|
||||
|
||||
return middlewares.Providers{
|
||||
Authorizer: authorizer,
|
||||
UserProvider: userProvider,
|
||||
|
@ -182,3 +153,53 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
|
|||
SessionProvider: sessionProvider,
|
||||
}, warnings, errors
|
||||
}
|
||||
|
||||
func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) {
|
||||
logger := logging.Logger()
|
||||
|
||||
var (
|
||||
failures []string
|
||||
err error
|
||||
)
|
||||
|
||||
if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil {
|
||||
logger.Errorf("Failure running the user provider startup check: %+v", err)
|
||||
|
||||
failures = append(failures, "user")
|
||||
}
|
||||
|
||||
if err = doStartupCheck(logger, "notification", providers.Notifier, config.Notifier.DisableStartupCheck); err != nil {
|
||||
logger.Errorf("Failure running the notification provider startup check: %+v", err)
|
||||
|
||||
failures = append(failures, "notification")
|
||||
}
|
||||
|
||||
if !config.NTP.DisableStartupCheck && !providers.Authorizer.IsSecondFactorEnabled() {
|
||||
logger.Debug("The NTP startup check was skipped due to there being no configured 2FA access control rules")
|
||||
} else if err = doStartupCheck(logger, "ntp", providers.NTP, config.NTP.DisableStartupCheck); err != nil {
|
||||
logger.Errorf("Failure running the user provider startup check: %+v", err)
|
||||
|
||||
failures = append(failures, "ntp")
|
||||
}
|
||||
|
||||
if len(failures) != 0 {
|
||||
logger.Fatalf("The following providers had fatal failures during startup: %s", strings.Join(failures, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) {
|
||||
if disabled {
|
||||
logger.Debugf("%s provider: startup check skipped as it is disabled", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
if provider == nil {
|
||||
return fmt.Errorf("unrecognized provider or it is not configured properly")
|
||||
}
|
||||
|
||||
if err = provider.StartupCheck(logger); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -28,6 +28,11 @@ type AutheliaCtx struct {
|
|||
Clock utils.Clock
|
||||
}
|
||||
|
||||
// ProviderWithStartupCheck represents a provider that has a startup check.
|
||||
type ProviderWithStartupCheck interface {
|
||||
StartupCheck(logger *logrus.Logger) (err error)
|
||||
}
|
||||
|
||||
// Providers contain all provider provided to Authelia.
|
||||
type Providers struct {
|
||||
Authorizer *authorization.Authorizer
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// MockNotifier is a mock of Notifier interface.
|
||||
|
@ -48,16 +49,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go
|
|||
}
|
||||
|
||||
// StartupCheck mocks base method.
|
||||
func (m *MockNotifier) StartupCheck() (bool, error) {
|
||||
func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartupCheck")
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret := m.ctrl.Call(m, "StartupCheck", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartupCheck indicates an expected call of StartupCheck.
|
||||
func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call {
|
||||
func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"reflect"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
)
|
||||
|
@ -78,3 +79,17 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) *
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1)
|
||||
}
|
||||
|
||||
// StartupCheck mocks base method.
|
||||
func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartupCheck", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartupCheck indicates an expected call of StartupCheck.
|
||||
func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
|
@ -22,28 +24,28 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File
|
|||
}
|
||||
}
|
||||
|
||||
// StartupCheck checks the file provider can write to the specified file.
|
||||
func (n *FileNotifier) StartupCheck() (bool, error) {
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) {
|
||||
dir := filepath.Dir(n.path)
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(dir, fileNotifierMode); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
} else if _, err = os.Stat(n.path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send send a identity verification link to a user.
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
package notification
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Notifier interface for sending the identity verification link.
|
||||
type Notifier interface {
|
||||
Send(recipient, subject, body, htmlBody string) error
|
||||
StartupCheck() (bool, error)
|
||||
Send(recipient, subject, body, htmlBody string) (err error)
|
||||
StartupCheck(logger *logrus.Logger) (err error)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
|
@ -220,39 +222,39 @@ func (n *SMTPNotifier) cleanup() {
|
|||
}
|
||||
}
|
||||
|
||||
// StartupCheck checks the server is functioning correctly and the configuration is correct.
|
||||
func (n *SMTPNotifier) StartupCheck() (bool, error) {
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) {
|
||||
if err := n.dial(); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
defer n.cleanup()
|
||||
|
||||
if err := n.client.Hello(n.configuration.Identifier); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := n.startTLS(); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := n.auth(); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := n.client.Mail(n.configuration.Sender); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := n.client.Rcpt(n.configuration.StartupCheckAddress); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := n.client.Reset(); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send is used to send an email to a recipient.
|
||||
|
|
|
@ -2,10 +2,12 @@ package ntp
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
@ -15,17 +17,21 @@ func NewProvider(config *schema.NTPConfiguration) *Provider {
|
|||
return &Provider{config}
|
||||
}
|
||||
|
||||
// StartupCheck checks if the system clock is not out of sync.
|
||||
func (p *Provider) StartupCheck() (failed bool, err error) {
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
|
||||
conn, err := net.Dial("udp", p.config.Address)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err)
|
||||
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err)
|
||||
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
version := ntpV4
|
||||
|
@ -36,7 +42,9 @@ func (p *Provider) StartupCheck() (failed bool, err error) {
|
|||
req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)}
|
||||
|
||||
if err := binary.Write(conn, binary.BigEndian, req); err != nil {
|
||||
return false, fmt.Errorf("could not write to the NTP server socket to validate the time desync: %w", err)
|
||||
logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
@ -44,12 +52,18 @@ func (p *Provider) StartupCheck() (failed bool, err error) {
|
|||
resp := &ntpPacket{}
|
||||
|
||||
if err := binary.Read(conn, binary.BigEndian, resp); err != nil {
|
||||
return false, fmt.Errorf("could not read from the NTP server socket to validate the time desync: %w", err)
|
||||
logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync)
|
||||
|
||||
ntpTime := ntpPacketToTime(resp)
|
||||
|
||||
return ntpIsOffsetTooLarge(maxOffset, now, ntpTime), nil
|
||||
if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result {
|
||||
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/validator"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
func TestShouldCheckNTP(t *testing.T) {
|
||||
|
@ -19,8 +20,7 @@ func TestShouldCheckNTP(t *testing.T) {
|
|||
sv := schema.NewStructValidator()
|
||||
validator.ValidateNTP(&config, sv)
|
||||
|
||||
NTP := NewProvider(&config)
|
||||
ntp := NewProvider(&config)
|
||||
|
||||
checkfailed, _ := NTP.StartupCheck()
|
||||
assert.Equal(t, false, checkfailed)
|
||||
assert.NoError(t, ntp.StartupCheck(logging.Logger()))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue