authelia/internal/storage/sql_provider_backend_postgr...

78 lines
3.6 KiB
Go

package storage
import (
"fmt"
"strings"
"time"
_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// PostgreSQLProvider is a storage provider backed by PostgreSQL.
//
// It embeds SQLProvider, so the embedded provider's fields and methods are promoted onto this type.
// NewPostgreSQLProvider overrides the embedded queries that need PostgreSQL specific statements or placeholders.
type PostgreSQLProvider struct {
SQLProvider
}
// NewPostgreSQLProvider returns a PostgreSQLProvider created from the PostgreSQL storage configuration and the
// encryption key. The embedded SQLProvider is created with the "pgx" driver name and the data source name
// produced by dataSourceNamePostgreSQL, after which the queries that differ for PostgreSQL are replaced.
func NewPostgreSQLProvider(config schema.PostgreSQLStorageConfiguration, encryptionKey string) (provider *PostgreSQLProvider) {
	provider = &PostgreSQLProvider{
		SQLProvider: NewSQLProvider(providerPostgres, "pgx", dataSourceNamePostgreSQL(config), encryptionKey),
	}

	// The statement that lists the existing tables is different for every provider.
	provider.sqlSelectExistingTables = queryPostgreSelectExistingTables

	// PostgreSQL has no UPSERT statement; these provider specific queries rely on its ON CONFLICT operation instead.
	provider.sqlUpsertU2FDevice = fmt.Sprintf(queryFmtPostgresUpsertU2FDevice, tableU2FDevices)
	provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations)
	provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences)
	provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtPostgresUpsertEncryptionValue, tableEncryption)

	// PostgreSQL needs every query containing a '?' placeholder rebound to the '$#' placeholder notation. Each
	// listed query is rebound in place, in the order given.
	needRebind := []*string{
		&provider.sqlFmtRenameTable,

		&provider.sqlSelectPreferred2FAMethod,
		&provider.sqlSelectUserInfo,

		&provider.sqlSelectExistsIdentityVerification,
		&provider.sqlInsertIdentityVerification,
		&provider.sqlDeleteIdentityVerification,

		&provider.sqlSelectTOTPConfig,
		&provider.sqlUpsertTOTPConfig,
		&provider.sqlDeleteTOTPConfig,
		&provider.sqlSelectTOTPConfigs,
		&provider.sqlUpdateTOTPConfigSecret,
		&provider.sqlUpdateTOTPConfigSecretByUsername,

		&provider.sqlSelectU2FDevice,

		&provider.sqlInsertAuthenticationAttempt,
		&provider.sqlSelectAuthenticationAttemptsByUsername,

		&provider.sqlInsertMigration,

		&provider.sqlSelectEncryptionValue,
	}

	for _, query := range needRebind {
		*query = provider.db.Rebind(*query)
	}

	return provider
}
// dataSourceNamePostgreSQL builds the space separated keyword/value connection string that is handed to the
// PostgreSQL driver from the PostgreSQL storage configuration.
//
// The user and password keywords are always included as single-quoted values, with any backslash or single quote
// inside them escaped by a backslash. The host and dbname keywords are only included when they are not empty, and
// the port keyword only when it is greater than zero. The connect_timeout keyword is always included and carries
// config.Timeout truncated to whole seconds.
func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) {
	// Within a single-quoted value a backslash and a single quote have to be preceded by a backslash. Without this
	// a credential containing either character terminates the value early and corrupts the settings that follow.
	quotedValue := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	args := []string{
		fmt.Sprintf("user='%s'", quotedValue.Replace(config.Username)),
		fmt.Sprintf("password='%s'", quotedValue.Replace(config.Password)),
	}

	if config.Host != "" {
		args = append(args, fmt.Sprintf("host=%s", config.Host))
	}

	if config.Port > 0 {
		args = append(args, fmt.Sprintf("port=%d", config.Port))
	}

	if config.Database != "" {
		args = append(args, fmt.Sprintf("dbname=%s", config.Database))
	}

	// The timeout is written as a whole number of seconds.
	args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second)))

	return strings.Join(args, " ")
}