From e9a383be0c57d36031b81e30be93b49e7d221cdc Mon Sep 17 00:00:00 2001 From: RPJosh Date: Fri, 23 Jun 2023 22:07:23 +0200 Subject: [PATCH] Add option to ban user by ip instead of username --- MyNotes.md | 9 +++++ config.template.yml | 2 + internal/configuration/schema/keys.go | 1 + internal/configuration/schema/server.go | 1 + internal/handlers/handler_firstfactor.go | 2 +- internal/handlers/response.go | 11 +++-- internal/mocks/storage.go | 8 ++-- internal/regulation/regulator.go | 4 +- internal/regulation/regulator_test.go | 40 ++++++++++--------- internal/storage/provider.go | 2 +- internal/storage/sql_provider.go | 15 ++++++- .../storage/sql_provider_backend_postgres.go | 1 + internal/storage/sql_provider_queries.go | 8 ++++ 13 files changed, 72 insertions(+), 32 deletions(-) diff --git a/MyNotes.md b/MyNotes.md index f02b9ec6c..d48e4aceb 100644 --- a/MyNotes.md +++ b/MyNotes.md @@ -58,6 +58,15 @@ Wenn die Konfiguration geändert wurde, müssen die Keys zur Validierung wieder go run ./cmd/authelia-gen code keys ``` +## Mocks abgeändert + +Wenn interfaces von den Mocks geändert werden, muss folgendes wieder ausgeführt werden: + +``` +export PATH=$PATH:$(go env GOPATH)/bin +go generate ./... +``` + ## Bauen Um ein Docker Image für authelia zu bauen, müssen die folgenden Befehle ausgeführt werden. diff --git a/config.template.yml b/config.template.yml index 75a5e4f6f..e58d21b53 100644 --- a/config.template.yml +++ b/config.template.yml @@ -85,6 +85,8 @@ server: # Even if TLS is configured in the server setting (under server.tls), the grcp server won't use TLS disableTLS: false + # By default the ban is issued for the user. With this options the IP instead of the user will be banned + use_ip_for_ban: true ## Server headers configuration/customization. headers: diff --git a/internal/configuration/schema/keys.go b/internal/configuration/schema/keys.go index 259bc6568..f1b013c08 100644 --- a/internal/configuration/schema/keys.go +++ b/internal/configuration/schema/keys.go @@ -265,6 +265,7 @@ var Keys = []string{ "server.asset_path", "server.disable_healthcheck", "server.disable_autho_https_redirect", + "server.use_ip_for_ban", "server.tls.certificate", "server.tls.key", "server.tls.client_certificates", diff --git a/internal/configuration/schema/server.go b/internal/configuration/schema/server.go index 9bdce09ae..8290221fd 100644 --- a/internal/configuration/schema/server.go +++ b/internal/configuration/schema/server.go @@ -11,6 +11,7 @@ type ServerConfiguration struct { AssetPath string `koanf:"asset_path"` DisableHealthcheck bool `koanf:"disable_healthcheck"` DisableAutoHttpsRedirect bool `koanf:"disable_autho_https_redirect"` + UseIPInsteadOfUserForBan bool `koanf:"use_ip_for_ban"` TLS ServerTLS `koanf:"tls"` Headers ServerHeaders `koanf:"headers"` diff --git a/internal/handlers/handler_firstfactor.go b/internal/handlers/handler_firstfactor.go index 75aea5e02..ba533ad3b 100644 --- a/internal/handlers/handler_firstfactor.go +++ b/internal/handlers/handler_firstfactor.go @@ -33,7 +33,7 @@ func FirstFactorPOST(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re return } - if bannedUntil, err := ctx.Providers.Regulator.Regulate(ctx, bodyJSON.Username); err != nil { + if bannedUntil, err := ctx.Providers.Regulator.Regulate(ctx, bodyJSON.Username, ctx.RemoteIP().String()); err != nil { if errors.Is(err, regulation.ErrUserIsBanned) { _ = markAuthenticationAttempt(ctx, false, &bannedUntil, bodyJSON.Username, regulation.AuthType1FA, nil) diff --git a/internal/handlers/response.go b/internal/handlers/response.go index 9b8bda53f..2d5ce0d9e 100644 --- a/internal/handlers/response.go +++ b/internal/handlers/response.go @@ -277,13 +277,18 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba if successful { ctx.Logger.Debugf("Successful %s authentication attempt made by user '%s'", authType, username) } else { + reasonPhrase := "by user '" + username + "'" + if ctx.Configuration.Server.UseIPInsteadOfUserForBan { + reasonPhrase = fmt.Sprintf("by ip %q (user %q)", ctx.RemoteIP().String(), username) + } + switch { case errAuth != nil: - ctx.Logger.Errorf("Unsuccessful %s authentication attempt by user '%s': %+v", authType, username, errAuth) + ctx.Logger.Errorf("Unsuccessful %s authentication attempt %s: %+v", authType, reasonPhrase, errAuth) case bannedUntil != nil: - ctx.Logger.Errorf("Unsuccessful %s authentication attempt by user '%s' and they are banned until %s", authType, username, bannedUntil) + ctx.Logger.Errorf("Unsuccessful %s authentication attempt %s and they are banned until %s", authType, reasonPhrase, bannedUntil) default: - ctx.Logger.Errorf("Unsuccessful %s authentication attempt by user '%s'", authType, username) + ctx.Logger.Errorf("Unsuccessful %s authentication attempt %s", authType, reasonPhrase) } } diff --git a/internal/mocks/storage.go b/internal/mocks/storage.go index c038ae7b7..71966681e 100644 --- a/internal/mocks/storage.go +++ b/internal/mocks/storage.go @@ -210,18 +210,18 @@ func (mr *MockStorageMockRecorder) FindIdentityVerification(arg0, arg1 interface } // LoadAuthenticationLogs mocks base method. -func (m *MockStorage) LoadAuthenticationLogs(arg0 context.Context, arg1 string, arg2 time.Time, arg3, arg4 int) ([]model.AuthenticationAttempt, error) { +func (m *MockStorage) LoadAuthenticationLogs(arg0 context.Context, arg1, arg2 string, arg3 time.Time, arg4, arg5 int) ([]model.AuthenticationAttempt, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadAuthenticationLogs", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "LoadAuthenticationLogs", arg0, arg1, arg2, arg3, arg4, arg5) ret0, _ := ret[0].([]model.AuthenticationAttempt) ret1, _ := ret[1].(error) return ret0, ret1 } // LoadAuthenticationLogs indicates an expected call of LoadAuthenticationLogs. -func (mr *MockStorageMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4, arg5) } // LoadOAuth2BlacklistedJTI mocks base method. diff --git a/internal/regulation/regulator.go b/internal/regulation/regulator.go index ab1995483..c3b376cc2 100644 --- a/internal/regulation/regulator.go +++ b/internal/regulation/regulator.go @@ -40,13 +40,13 @@ func (r *Regulator) Mark(ctx Context, successful, banned bool, username, request // Regulate the authentication attempts for a given user. // This method returns ErrUserIsBanned if the user is banned along with the time until when the user is banned. -func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, error) { +func (r *Regulator) Regulate(ctx context.Context, username string, remoteIp string) (time.Time, error) { // If there is regulation configuration, no regulation applies. if !r.enabled { return time.Time{}, nil } - attempts, err := r.store.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0) + attempts, err := r.store.LoadAuthenticationLogs(ctx, username, remoteIp, r.clock.Now().Add(-r.config.BanTime), 10, 0) if err != nil { return time.Time{}, nil } diff --git a/internal/regulation/regulator_test.go b/internal/regulation/regulator_test.go index a7156dbd0..a9d7e53da 100644 --- a/internal/regulation/regulator_test.go +++ b/internal/regulation/regulator_test.go @@ -23,6 +23,8 @@ type RegulatorSuite struct { mock *mocks.MockAutheliaCtx } +// @TODO +// Extend this test for IP ban :) func (s *RegulatorSuite) SetupTest() { s.mock = mocks.NewMockAutheliaCtx(s.T()) s.mock.Ctx.Configuration.Regulation = schema.RegulationConfiguration{ @@ -58,9 +60,9 @@ func (s *RegulatorSuite) TestShouldMark() { func (s *RegulatorSuite) TestShouldHandleRegulateError() { regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - s.mock.StorageMock.EXPECT().LoadAuthenticationLogs(s.mock.Ctx, "john", s.mock.Clock.Now().Add(-s.mock.Ctx.Configuration.Regulation.BanTime), 10, 0).Return(nil, fmt.Errorf("failed")) + s.mock.StorageMock.EXPECT().LoadAuthenticationLogs(s.mock.Ctx, "john", "127.0.0.1", s.mock.Clock.Now().Add(-s.mock.Ctx.Configuration.Regulation.BanTime), 10, 0).Return(nil, fmt.Errorf("failed")) - until, err := regulator.Regulate(s.mock.Ctx, "john") + until, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") s.NoError(err) s.Equal(time.Time{}, until) @@ -76,12 +78,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() { } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.NoError(s.T(), err) } @@ -107,12 +109,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.NoError(s.T(), err) } @@ -143,12 +145,12 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() { } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } @@ -176,12 +178,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() { } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } @@ -200,12 +202,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() { } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.NoError(s.T(), err) } @@ -232,12 +234,12 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() { } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.NoError(s.T(), err) } @@ -268,12 +270,12 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(s.mock.Ctx.Configuration.Regulation, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.NoError(s.T(), err) } @@ -303,7 +305,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { } s.mock.StorageMock.EXPECT(). - LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). + LoadAuthenticationLogs(s.mock.Ctx, gomock.Eq("john"), "127.0.0.1", gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) // Check Disabled Functionality. @@ -314,7 +316,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { } regulator := regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock) - _, err := regulator.Regulate(s.mock.Ctx, "john") + _, err := regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.NoError(s.T(), err) // Check Enabled Functionality. @@ -325,6 +327,6 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { } regulator = regulation.NewRegulator(config, s.mock.StorageMock, &s.mock.Clock) - _, err = regulator.Regulate(s.mock.Ctx, "john") + _, err = regulator.Regulate(s.mock.Ctx, "john", "127.0.0.1") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } diff --git a/internal/storage/provider.go b/internal/storage/provider.go index 651cdadfa..73c823476 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -90,5 +90,5 @@ type Provider interface { // RegulatorProvider is an interface providing storage capabilities for persisting any kind of data related to the regulator. type RegulatorProvider interface { AppendAuthenticationLog(ctx context.Context, attempt model.AuthenticationAttempt) (err error) - LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) + LoadAuthenticationLogs(ctx context.Context, username string, ip string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) } diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 0c4d6376e..16490a992 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -33,6 +33,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), + sqlSelectAuthenticationAttemptyByIP: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByIP, tableAuthenticationLogs), sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification), sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification), @@ -149,6 +150,7 @@ type SQLProvider struct { // Table: authentication_logs. sqlInsertAuthenticationAttempt string sqlSelectAuthenticationAttemptsByUsername string + sqlSelectAuthenticationAttemptyByIP string // Table: identity_verification. sqlInsertIdentityVerification string @@ -1021,10 +1023,18 @@ func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model } // LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log. -func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) { +func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, ip string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) { attempts = make([]model.AuthenticationAttempt, 0, limit) - if err = p.db.SelectContext(ctx, &attempts, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page); err != nil { + // Dynmaic values based on ip / username ban + query := p.sqlSelectAuthenticationAttemptsByUsername + placeholder := username + if p.config.Server.UseIPInsteadOfUserForBan { + query = p.sqlSelectAuthenticationAttemptyByIP + placeholder = ip + } + + if err = p.db.SelectContext(ctx, &attempts, query, fromDate, placeholder, limit, limit*page); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoAuthenticationLogs } @@ -1033,4 +1043,5 @@ func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username strin } return attempts, nil + } diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index ef7054951..6f716db60 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -71,6 +71,7 @@ func NewPostgreSQLProvider(config *schema.Configuration, caCertPool *x509.CertPo provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt) provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername) + provider.sqlSelectAuthenticationAttemptyByIP = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptyByIP) provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration) provider.sqlSelectMigrations = provider.db.Rebind(provider.sqlSelectMigrations) diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index b089e23ce..09a1e043e 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -211,6 +211,14 @@ const ( ORDER BY time DESC LIMIT ? OFFSET ?;` + + queryFmtSelect1FAAuthenticationLogEntryByIP = ` + SELECT time, successful, username + FROM %s + WHERE time > ? AND remote_ip = ? AND auth_type = '1FA' AND banned = FALSE + ORDER BY time DESC + LIMIT ? + OFFSET ?;` ) const (