382 lines
10 KiB
Go
382 lines
10 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"github.com/authelia/authelia/v4/internal/model"
|
|
"github.com/authelia/authelia/v4/internal/utils"
|
|
)
|
|
|
|
// SchemaTables returns a list of tables.
|
|
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
|
|
var rows *sqlx.Rows
|
|
|
|
switch p.schema {
|
|
case "":
|
|
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
|
|
default:
|
|
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables, p.schema)
|
|
}
|
|
|
|
if err != nil {
|
|
return tables, err
|
|
}
|
|
|
|
defer func() {
|
|
if err := rows.Close(); err != nil {
|
|
p.log.Errorf(logFmtErrClosingConn, err)
|
|
}
|
|
}()
|
|
|
|
var table string
|
|
|
|
for rows.Next() {
|
|
err = rows.Scan(&table)
|
|
if err != nil {
|
|
return []string{}, err
|
|
}
|
|
|
|
tables = append(tables, table)
|
|
}
|
|
|
|
return tables, nil
|
|
}
|
|
|
|
// SchemaVersion returns the version of the schema.
|
|
func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error) {
|
|
tables, err := p.SchemaTables(ctx)
|
|
if err != nil {
|
|
return -2, err
|
|
}
|
|
|
|
if len(tables) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
if utils.IsStringInSlice(tableMigrations, tables) {
|
|
migration, err := p.schemaLatestMigration(ctx)
|
|
if err != nil {
|
|
return -2, err
|
|
}
|
|
|
|
return migration.After, nil
|
|
}
|
|
|
|
var tablesV1 = []string{tableDuoDevices, tableEncryption, tableIdentityVerification, tableMigrations, tableTOTPConfigurations}
|
|
|
|
if utils.IsStringSliceContainsAll(tablesPre1, tables) {
|
|
if utils.IsStringSliceContainsAny(tablesV1, tables) {
|
|
return -2, errors.New("pre1 schema contains v1 tables it shouldn't contain")
|
|
}
|
|
|
|
return -1, nil
|
|
}
|
|
|
|
return 0, nil
|
|
}
|
|
|
|
// SchemaLatestVersion returns the latest version available for migration.
|
|
func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
|
|
return latestMigrationVersion(p.name)
|
|
}
|
|
|
|
// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
|
|
func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
|
|
current, err := p.SchemaVersion(ctx)
|
|
if err != nil {
|
|
return migrations, err
|
|
}
|
|
|
|
if version == 0 {
|
|
version = SchemaLatest
|
|
}
|
|
|
|
if current >= version {
|
|
return migrations, ErrNoAvailableMigrations
|
|
}
|
|
|
|
return loadMigrations(p.name, current, version)
|
|
}
|
|
|
|
// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
|
|
func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []model.SchemaMigration, err error) {
|
|
current, err := p.SchemaVersion(ctx)
|
|
if err != nil {
|
|
return migrations, err
|
|
}
|
|
|
|
if current <= version {
|
|
return migrations, ErrNoAvailableMigrations
|
|
}
|
|
|
|
return loadMigrations(p.name, current, version)
|
|
}
|
|
|
|
// SchemaMigrationHistory returns migration history rows.
|
|
func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations []model.Migration, err error) {
|
|
rows, err := p.db.QueryxContext(ctx, p.sqlSelectMigrations)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer func() {
|
|
if err := rows.Close(); err != nil {
|
|
p.log.Errorf(logFmtErrClosingConn, err)
|
|
}
|
|
}()
|
|
|
|
var migration model.Migration
|
|
|
|
for rows.Next() {
|
|
err = rows.StructScan(&migration)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migrations = append(migrations, migration)
|
|
}
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
// SchemaMigrate migrates from the current version to the provided version.
|
|
func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) {
|
|
var (
|
|
tx *sqlx.Tx
|
|
conn SQLXConnection
|
|
)
|
|
|
|
if p.name != providerMySQL {
|
|
if tx, err = p.db.BeginTxx(ctx, nil); err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
conn = tx
|
|
} else {
|
|
conn = p.db
|
|
}
|
|
|
|
currentVersion, err := p.SchemaVersion(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if currentVersion != 0 {
|
|
if err = p.schemaMigrateLock(ctx, conn); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
if err = p.schemaMigrate(ctx, conn, currentVersion, version); err != nil {
|
|
if tx != nil && err == ErrNoMigrationsFound {
|
|
_ = tx.Rollback()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
if tx != nil {
|
|
if err = tx.Commit(); err != nil {
|
|
if rerr := tx.Rollback(); rerr != nil {
|
|
return fmt.Errorf("failed to commit the transaction with: commit error: %w, rollback error: %+v", err, rerr)
|
|
}
|
|
|
|
return fmt.Errorf("failed to commit the transaction but it has been rolled back: commit error: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrate(ctx context.Context, conn SQLXConnection, prior, target int) (err error) {
|
|
migrations, err := loadMigrations(p.name, prior, target)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(migrations) == 0 {
|
|
return ErrNoMigrationsFound
|
|
}
|
|
|
|
p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
|
|
|
|
for i, migration := range migrations {
|
|
if migration.Up && prior == 0 && i == 1 {
|
|
if err = p.schemaMigrateLock(ctx, conn); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err = p.schemaMigrateApply(ctx, conn, migration); err != nil {
|
|
return p.schemaMigrateRollback(ctx, conn, prior, migration.After(), err)
|
|
}
|
|
}
|
|
|
|
p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection) (err error) {
|
|
if p.name != providerPostgres {
|
|
return nil
|
|
}
|
|
|
|
if _, err = conn.ExecContext(ctx, fmt.Sprintf(queryFmtPostgreSQLLockTable, tableMigrations, "ACCESS EXCLUSIVE")); err != nil {
|
|
return fmt.Errorf("failed to lock tables: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
|
|
if migration.NotEmpty() {
|
|
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
|
|
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
|
|
}
|
|
|
|
if migration.Version == 1 && migration.Up {
|
|
// Add the schema encryption value if upgrading to v1.
|
|
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
|
|
if migration.Version == 1 && !migration.Up {
|
|
return nil
|
|
}
|
|
|
|
if _, err = conn.ExecContext(ctx, p.sqlInsertMigration, time.Now(), migration.Before(), migration.After(), utils.Version()); err != nil {
|
|
return fmt.Errorf("failed inserting migration record: %w", err)
|
|
}
|
|
|
|
p.log.Debugf("Storage schema migrated from version %d to %d", migration.Before(), migration.After())
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, conn SQLXConnection, prior, after int, merr error) (err error) {
|
|
switch tx := conn.(type) {
|
|
case *sqlx.Tx:
|
|
return p.schemaMigrateRollbackWithTx(ctx, tx, merr)
|
|
default:
|
|
return p.schemaMigrateRollbackWithoutTx(ctx, prior, after, merr)
|
|
}
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrateRollbackWithTx(_ context.Context, tx *sqlx.Tx, merr error) (err error) {
|
|
if err = tx.Rollback(); err != nil {
|
|
return fmt.Errorf("error applying rollback %+v. rollback caused by: %w", err, merr)
|
|
}
|
|
|
|
return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr)
|
|
}
|
|
|
|
func (p *SQLProvider) schemaMigrateRollbackWithoutTx(ctx context.Context, prior, after int, merr error) (err error) {
|
|
migrations, err := loadMigrations(p.name, after, prior)
|
|
if err != nil {
|
|
return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %w", prior, after, err, merr)
|
|
}
|
|
|
|
for _, migration := range migrations {
|
|
if err = p.schemaMigrateApply(ctx, p.db, migration); err != nil {
|
|
return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %w", migration.Before(), migration.After(), err, merr)
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("migration rollback complete. rollback caused by: %w", merr)
|
|
}
|
|
|
|
func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *model.Migration, err error) {
|
|
migration = &model.Migration{}
|
|
|
|
if err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return migration, nil
|
|
}
|
|
|
|
func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) {
|
|
switch {
|
|
case currentVersion == -1:
|
|
return fmt.Errorf(errFmtMigrationPre1, "up from", errFmtMigrationPre1SuggestedVersion)
|
|
case targetVersion == -1:
|
|
return fmt.Errorf(errFmtMigrationPre1, "down to", fmt.Sprintf("you should downgrade to schema version 1 using the current authelia version then use the suggested authelia version to downgrade to pre1: %s", errFmtMigrationPre1SuggestedVersion))
|
|
}
|
|
|
|
if targetVersion == currentVersion {
|
|
return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion)
|
|
}
|
|
|
|
latest, err := latestMigrationVersion(providerName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if currentVersion > latest {
|
|
return fmt.Errorf(errFmtSchemaCurrentGreaterThanLatestKnown, latest)
|
|
}
|
|
|
|
if up {
|
|
if targetVersion < currentVersion {
|
|
return fmt.Errorf(ErrFmtMigrateUpTargetLessThanCurrent, targetVersion, currentVersion)
|
|
}
|
|
|
|
if targetVersion == SchemaLatest && latest == currentVersion {
|
|
return ErrSchemaAlreadyUpToDate
|
|
}
|
|
|
|
if targetVersion != SchemaLatest && latest < targetVersion {
|
|
return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest)
|
|
}
|
|
} else {
|
|
if targetVersion < 0 {
|
|
return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion)
|
|
}
|
|
|
|
if targetVersion > currentVersion {
|
|
return fmt.Errorf(ErrFmtMigrateDownTargetGreaterThanCurrent, targetVersion, currentVersion)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SchemaVersionToString returns a version string given a version number.
|
|
func SchemaVersionToString(version int) (versionStr string) {
|
|
switch version {
|
|
case -2:
|
|
return "unknown"
|
|
case -1:
|
|
return "pre1"
|
|
case 0:
|
|
return na
|
|
default:
|
|
return strconv.Itoa(version)
|
|
}
|
|
}
|