diff --git a/internal/configuration/schema/storage.go b/internal/configuration/schema/storage.go index f806fd3ef..3503689db 100644 --- a/internal/configuration/schema/storage.go +++ b/internal/configuration/schema/storage.go @@ -54,3 +54,11 @@ type StorageConfiguration struct { var DefaultSQLStorageConfiguration = SQLStorageConfiguration{ Timeout: 5 * time.Second, } + +// DefaultPostgreSQLStorageConfiguration represents the default PostgreSQL configuration. +var DefaultPostgreSQLStorageConfiguration = PostgreSQLStorageConfiguration{ + Schema: "public", + SSL: PostgreSQLSSLStorageConfiguration{ + Mode: "disable", + }, +} diff --git a/internal/configuration/validator/storage.go b/internal/configuration/validator/storage.go index 0a2c69f8b..bbed3e6ae 100644 --- a/internal/configuration/validator/storage.go +++ b/internal/configuration/validator/storage.go @@ -52,13 +52,17 @@ func validateSQLConfiguration(configuration *schema.SQLStorageConfiguration, val func validatePostgreSQLConfiguration(configuration *schema.PostgreSQLStorageConfiguration, validator *schema.StructValidator) { validateSQLConfiguration(&configuration.SQLStorageConfiguration, validator, "postgres") + if configuration.Schema == "" { + configuration.Schema = schema.DefaultPostgreSQLStorageConfiguration.Schema + } + // Deprecated. TODO: Remove in v4.36.0. if configuration.SSLMode != "" && configuration.SSL.Mode == "" { configuration.SSL.Mode = configuration.SSLMode } if configuration.SSL.Mode == "" { - configuration.SSL.Mode = testModeDisabled + configuration.SSL.Mode = schema.DefaultPostgreSQLStorageConfiguration.SSL.Mode } else if !utils.IsStringInSlice(configuration.SSL.Mode, storagePostgreSQLValidSSLModes) { validator.Push(fmt.Errorf(errFmtStoragePostgreSQLInvalidSSLMode, configuration.SSL.Mode, strings.Join(storagePostgreSQLValidSSLModes, "', '"))) } diff --git a/internal/configuration/validator/storage_test.go b/internal/configuration/validator/storage_test.go index b1aa888c3..78bbc27db 100644 --- a/internal/configuration/validator/storage_test.go +++ b/internal/configuration/validator/storage_test.go @@ -104,7 +104,7 @@ func (suite *StorageSuite) TestShouldValidatePostgreSQLHostUsernamePasswordAndDa suite.Assert().Len(suite.validator.Errors(), 0) } -func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault() { +func (suite *StorageSuite) TestShouldValidatePostgresSSLModeAndSchemaDefaults() { suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{ SQLStorageConfiguration: schema.SQLStorageConfiguration{ Host: "db1", @@ -120,6 +120,30 @@ func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault() suite.Assert().Len(suite.validator.Errors(), 0) suite.Assert().Equal("disable", suite.configuration.PostgreSQL.SSL.Mode) + suite.Assert().Equal("public", suite.configuration.PostgreSQL.Schema) +} + +func (suite *StorageSuite) TestShouldValidatePostgresDefaultsDontOverrideConfiguration() { + suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{ + SQLStorageConfiguration: schema.SQLStorageConfiguration{ + Host: "db1", + Username: "myuser", + Password: "pass", + Database: "database", + }, + Schema: "authelia", + SSL: schema.PostgreSQLSSLStorageConfiguration{ + Mode: "require", + }, + } + + ValidateStorage(suite.configuration, suite.validator) + + suite.Assert().Len(suite.validator.Warnings(), 0) + suite.Assert().Len(suite.validator.Errors(), 0) + + suite.Assert().Equal("require", suite.configuration.PostgreSQL.SSL.Mode) + suite.Assert().Equal("authelia", suite.configuration.PostgreSQL.Schema) } func (suite *StorageSuite) TestShouldValidatePostgresSSLModeMustBeValid() { diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 5bb89cec3..4d4b71ad4 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -79,6 +79,7 @@ type SQLProvider struct { key [32]byte name string driverName string + schema string config *schema.Configuration errOpen error diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 07868e0e6..e70718dd7 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -57,6 +57,8 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration) provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue) + provider.schema = config.Storage.PostgreSQL.Schema + return provider } @@ -66,20 +68,14 @@ func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dat fmt.Sprintf("user='%s'", config.Username), fmt.Sprintf("password='%s'", config.Password), fmt.Sprintf("dbname=%s", config.Database), + fmt.Sprintf("search_path=%s", config.Schema), + fmt.Sprintf("sslmode=%s", config.SSL.Mode), } if config.Port > 0 { args = append(args, fmt.Sprintf("port=%d", config.Port)) } - if config.Schema != "" { - args = append(args, fmt.Sprintf("search_path=%s", config.Schema)) - } - - if config.SSL.Mode != "" { - args = append(args, fmt.Sprintf("sslmode=%s", config.SSL.Mode)) - } - if config.SSL.RootCertificate != "" { args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate)) } diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index 5ffbb1cf8..1cb9de960 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -25,7 +25,7 @@ const ( queryPostgreSelectExistingTables = ` SELECT table_name FROM information_schema.tables - WHERE table_type = 'BASE TABLE' AND table_schema = 'public';` + WHERE table_type = 'BASE TABLE' AND table_schema = $1;` querySQLiteSelectExistingTables = ` SELECT name diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index 9597e07bb..225f57321 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -7,13 +7,23 @@ import ( "strconv" "time" + "github.com/jmoiron/sqlx" + "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/utils" ) // SchemaTables returns a list of tables. func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) { - rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables) + var rows *sqlx.Rows + + switch p.schema { + case "": + rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables) + default: + rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables, p.schema) + } + if err != nil { return tables, err }