fix(storage): postgres schema hardcoded for tables query (#2667)

This removes the hardcoded schema value from the PostgreSQL existing tables query, making it compatible with the new schema config option.
pull/2665/head^2
James Elliott 2021-12-03 17:29:55 +11:00 committed by GitHub
parent ec1cc3d64e
commit 95a5e326a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 12 deletions

View File

@ -54,3 +54,11 @@ type StorageConfiguration struct {
var DefaultSQLStorageConfiguration = SQLStorageConfiguration{ var DefaultSQLStorageConfiguration = SQLStorageConfiguration{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
// DefaultPostgreSQLStorageConfiguration represents the default PostgreSQL configuration.
var DefaultPostgreSQLStorageConfiguration = PostgreSQLStorageConfiguration{
Schema: "public",
SSL: PostgreSQLSSLStorageConfiguration{
Mode: "disable",
},
}

View File

@ -52,13 +52,17 @@ func validateSQLConfiguration(configuration *schema.SQLStorageConfiguration, val
func validatePostgreSQLConfiguration(configuration *schema.PostgreSQLStorageConfiguration, validator *schema.StructValidator) { func validatePostgreSQLConfiguration(configuration *schema.PostgreSQLStorageConfiguration, validator *schema.StructValidator) {
validateSQLConfiguration(&configuration.SQLStorageConfiguration, validator, "postgres") validateSQLConfiguration(&configuration.SQLStorageConfiguration, validator, "postgres")
if configuration.Schema == "" {
configuration.Schema = schema.DefaultPostgreSQLStorageConfiguration.Schema
}
// Deprecated. TODO: Remove in v4.36.0. // Deprecated. TODO: Remove in v4.36.0.
if configuration.SSLMode != "" && configuration.SSL.Mode == "" { if configuration.SSLMode != "" && configuration.SSL.Mode == "" {
configuration.SSL.Mode = configuration.SSLMode configuration.SSL.Mode = configuration.SSLMode
} }
if configuration.SSL.Mode == "" { if configuration.SSL.Mode == "" {
configuration.SSL.Mode = testModeDisabled configuration.SSL.Mode = schema.DefaultPostgreSQLStorageConfiguration.SSL.Mode
} else if !utils.IsStringInSlice(configuration.SSL.Mode, storagePostgreSQLValidSSLModes) { } else if !utils.IsStringInSlice(configuration.SSL.Mode, storagePostgreSQLValidSSLModes) {
validator.Push(fmt.Errorf(errFmtStoragePostgreSQLInvalidSSLMode, configuration.SSL.Mode, strings.Join(storagePostgreSQLValidSSLModes, "', '"))) validator.Push(fmt.Errorf(errFmtStoragePostgreSQLInvalidSSLMode, configuration.SSL.Mode, strings.Join(storagePostgreSQLValidSSLModes, "', '")))
} }

View File

@ -104,7 +104,7 @@ func (suite *StorageSuite) TestShouldValidatePostgreSQLHostUsernamePasswordAndDa
suite.Assert().Len(suite.validator.Errors(), 0) suite.Assert().Len(suite.validator.Errors(), 0)
} }
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault() { func (suite *StorageSuite) TestShouldValidatePostgresSSLModeAndSchemaDefaults() {
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{ suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
SQLStorageConfiguration: schema.SQLStorageConfiguration{ SQLStorageConfiguration: schema.SQLStorageConfiguration{
Host: "db1", Host: "db1",
@ -120,6 +120,30 @@ func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault()
suite.Assert().Len(suite.validator.Errors(), 0) suite.Assert().Len(suite.validator.Errors(), 0)
suite.Assert().Equal("disable", suite.configuration.PostgreSQL.SSL.Mode) suite.Assert().Equal("disable", suite.configuration.PostgreSQL.SSL.Mode)
suite.Assert().Equal("public", suite.configuration.PostgreSQL.Schema)
}
func (suite *StorageSuite) TestShouldValidatePostgresDefaultsDontOverrideConfiguration() {
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
SQLStorageConfiguration: schema.SQLStorageConfiguration{
Host: "db1",
Username: "myuser",
Password: "pass",
Database: "database",
},
Schema: "authelia",
SSL: schema.PostgreSQLSSLStorageConfiguration{
Mode: "require",
},
}
ValidateStorage(suite.configuration, suite.validator)
suite.Assert().Len(suite.validator.Warnings(), 0)
suite.Assert().Len(suite.validator.Errors(), 0)
suite.Assert().Equal("require", suite.configuration.PostgreSQL.SSL.Mode)
suite.Assert().Equal("authelia", suite.configuration.PostgreSQL.Schema)
} }
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeMustBeValid() { func (suite *StorageSuite) TestShouldValidatePostgresSSLModeMustBeValid() {

View File

@ -79,6 +79,7 @@ type SQLProvider struct {
key [32]byte key [32]byte
name string name string
driverName string driverName string
schema string
config *schema.Configuration config *schema.Configuration
errOpen error errOpen error

View File

@ -57,6 +57,8 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration) provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration)
provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue) provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue)
provider.schema = config.Storage.PostgreSQL.Schema
return provider return provider
} }
@ -66,20 +68,14 @@ func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dat
fmt.Sprintf("user='%s'", config.Username), fmt.Sprintf("user='%s'", config.Username),
fmt.Sprintf("password='%s'", config.Password), fmt.Sprintf("password='%s'", config.Password),
fmt.Sprintf("dbname=%s", config.Database), fmt.Sprintf("dbname=%s", config.Database),
fmt.Sprintf("search_path=%s", config.Schema),
fmt.Sprintf("sslmode=%s", config.SSL.Mode),
} }
if config.Port > 0 { if config.Port > 0 {
args = append(args, fmt.Sprintf("port=%d", config.Port)) args = append(args, fmt.Sprintf("port=%d", config.Port))
} }
if config.Schema != "" {
args = append(args, fmt.Sprintf("search_path=%s", config.Schema))
}
if config.SSL.Mode != "" {
args = append(args, fmt.Sprintf("sslmode=%s", config.SSL.Mode))
}
if config.SSL.RootCertificate != "" { if config.SSL.RootCertificate != "" {
args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate)) args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate))
} }

View File

@ -25,7 +25,7 @@ const (
queryPostgreSelectExistingTables = ` queryPostgreSelectExistingTables = `
SELECT table_name SELECT table_name
FROM information_schema.tables FROM information_schema.tables
WHERE table_type = 'BASE TABLE' AND table_schema = 'public';` WHERE table_type = 'BASE TABLE' AND table_schema = $1;`
querySQLiteSelectExistingTables = ` querySQLiteSelectExistingTables = `
SELECT name SELECT name

View File

@ -7,13 +7,23 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/jmoiron/sqlx"
"github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// SchemaTables returns a list of tables. // SchemaTables returns a list of tables.
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) { func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables) var rows *sqlx.Rows
switch p.schema {
case "":
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
default:
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables, p.schema)
}
if err != nil { if err != nil {
return tables, err return tables, err
} }