diff --git a/internal/model/schema_migration.go b/internal/model/schema_migration.go index a18b8dc06..313b88436 100644 --- a/internal/model/schema_migration.go +++ b/internal/model/schema_migration.go @@ -1,5 +1,9 @@ package model +import ( + "strings" +) + // SchemaMigration represents an intended migration. type SchemaMigration struct { Version int @@ -9,6 +13,11 @@ type SchemaMigration struct { Query string } +// NotEmpty returns true if the SchemaMigration is not an empty string. +func (m SchemaMigration) NotEmpty() bool { + return len(strings.TrimSpace(m.Query)) != 0 +} + // Before returns the version the schema should be at Before the migration is applied. func (m SchemaMigration) Before() (before int) { if m.Up { diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go index 18d79aa25..bfc78e954 100644 --- a/internal/storage/migrations.go +++ b/internal/storage/migrations.go @@ -4,6 +4,7 @@ import ( "embed" "errors" "fmt" + "io/fs" "sort" "strconv" "strings" @@ -15,8 +16,12 @@ import ( var migrationsFS embed.FS func latestMigrationVersion(providerName string) (version int, err error) { - entries, err := migrationsFS.ReadDir("migrations") - if err != nil { + var ( + entries []fs.DirEntry + migration model.SchemaMigration + ) + + if entries, err = migrationsFS.ReadDir("migrations"); err != nil { return -1, err } @@ -25,21 +30,20 @@ func latestMigrationVersion(providerName string) (version int, err error) { continue } - m, err := scanMigration(entry.Name()) - if err != nil { + if migration, err = scanMigration(entry.Name()); err != nil { return -1, err } - if m.Provider != providerName { + if migration.Provider != providerName && migration.Provider != providerAll { continue } - if !m.Up { + if !migration.Up { continue } - if m.Version > version { - version = m.Version + if migration.Version > version { + version = migration.Version } } @@ -50,12 +54,17 @@ func latestMigrationVersion(providerName string) (version int, err error) { // target versions. If the target version is -1 this indicates the latest version. If the target version is 0 // this indicates the database zero state. func loadMigrations(providerName string, prior, target int) (migrations []model.SchemaMigration, err error) { - if prior == target && (prior != -1 || target != -1) { + if prior == target { return nil, ErrMigrateCurrentVersionSameAsTarget } - entries, err := migrationsFS.ReadDir("migrations") - if err != nil { + var ( + migrationsAll []model.SchemaMigration + migration model.SchemaMigration + entries []fs.DirEntry + ) + + if entries, err = migrationsFS.ReadDir("migrations"); err != nil { return nil, err } @@ -66,8 +75,7 @@ func loadMigrations(providerName string, prior, target int) (migrations []model. continue } - migration, err := scanMigration(entry.Name()) - if err != nil { + if migration, err = scanMigration(entry.Name()); err != nil { return nil, err } @@ -75,7 +83,28 @@ func loadMigrations(providerName string, prior, target int) (migrations []model. continue } - migrations = append(migrations, migration) + if migration.Provider == providerAll { + migrationsAll = append(migrationsAll, migration) + } else { + migrations = append(migrations, migration) + } + } + + // Add "all" migrations for versions that don't exist. + for _, am := range migrationsAll { + found := false + + for _, m := range migrations { + if m.Version == am.Version { + found = true + + break + } + } + + if !found { + migrations = append(migrations, am) + } } if up { @@ -103,7 +132,7 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m return true } - if target != -1 && (migration.Version > target || migration.Version <= prior) { + if migration.Version > target || migration.Version <= prior { // Skip if the migration version is greater than the target or less than or equal to the previous version. return true } @@ -113,12 +142,6 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m return true } - if migration.Version == 1 && target == -1 { - // Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data - // preventing a successful migration. - return true - } - if migration.Version <= target || migration.Version > prior { // Skip the migration if we want to go down and the migration version is less than or equal to the target // or greater than the previous version. @@ -141,8 +164,9 @@ func scanMigration(m string) (migration model.SchemaMigration, err error) { Provider: result[reMigration.SubexpIndex("Provider")], } - data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)) - if err != nil { + var data []byte + + if data, err = migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)); err != nil { return model.SchemaMigration{}, err } diff --git a/internal/storage/migrations/V0007.ConsistencyFixes.postgres.up.sql b/internal/storage/migrations/V0007.ConsistencyFixes.postgres.up.sql index a0c50cc90..52ad1df16 100644 --- a/internal/storage/migrations/V0007.ConsistencyFixes.postgres.up.sql +++ b/internal/storage/migrations/V0007.ConsistencyFixes.postgres.up.sql @@ -56,30 +56,8 @@ ALTER TABLE totp_configurations DROP INDEX IF EXISTS totp_configurations_username_key1; DROP INDEX IF EXISTS totp_configurations_username_key; -ALTER TABLE totp_configurations - RENAME TO _bkp_UP_V0007_totp_configurations; - -CREATE TABLE IF NOT EXISTS totp_configurations ( - id SERIAL CONSTRAINT totp_configurations_pkey PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, - username VARCHAR(100) NOT NULL, - issuer VARCHAR(100), - algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1', - digits INTEGER NOT NULL DEFAULT 6, - period INTEGER NOT NULL DEFAULT 30, - secret BYTEA NOT NULL -); - CREATE UNIQUE INDEX totp_configurations_username_key ON totp_configurations (username); -INSERT INTO totp_configurations (created_at, last_used_at, username, issuer, algorithm, digits, period, secret) -SELECT created_at, last_used_at, username, issuer, algorithm, digits, period, secret -FROM _bkp_UP_V0007_totp_configurations -ORDER BY id; - -DROP TABLE IF EXISTS _bkp_UP_V0007_totp_configurations; - ALTER TABLE webauthn_devices DROP CONSTRAINT IF EXISTS webauthn_devices_username_description_key1, DROP CONSTRAINT IF EXISTS webauthn_devices_kid_key1, @@ -97,34 +75,9 @@ DROP INDEX IF EXISTS webauthn_devices_username_description_key; DROP INDEX IF EXISTS webauthn_devices_kid_key; DROP INDEX IF EXISTS webauthn_devices_lookup_key; -ALTER TABLE webauthn_devices - RENAME TO _bkp_UP_V0007_webauthn_devices; - -CREATE TABLE IF NOT EXISTS webauthn_devices ( - id SERIAL CONSTRAINT webauthn_devices_pkey PRIMARY KEY, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, - last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, - rpid TEXT, - username VARCHAR(100) NOT NULL, - description VARCHAR(30) NOT NULL DEFAULT 'Primary', - kid VARCHAR(512) NOT NULL, - public_key BYTEA NOT NULL, - attestation_type VARCHAR(32), - transport VARCHAR(20) DEFAULT '', - aaguid CHAR(36) NOT NULL, - sign_count INTEGER DEFAULT 0, - clone_warning BOOLEAN NOT NULL DEFAULT FALSE -); - CREATE UNIQUE INDEX webauthn_devices_kid_key ON webauthn_devices (kid); CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (username, description); -INSERT INTO webauthn_devices (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning) -SELECT created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning -FROM _bkp_UP_V0007_webauthn_devices; - -DROP TABLE IF EXISTS _bkp_UP_V0007_webauthn_devices; - ALTER TABLE oauth2_consent_session DROP CONSTRAINT oauth2_consent_session_subject_fkey, DROP CONSTRAINT oauth2_consent_session_preconfiguration_fkey; diff --git a/internal/storage/migrations/V0009.FixConstraints.all.down.sql b/internal/storage/migrations/V0009.FixConstraints.all.down.sql new file mode 100644 index 000000000..e69de29bb diff --git a/internal/storage/migrations/V0009.FixConstraints.all.up.sql b/internal/storage/migrations/V0009.FixConstraints.all.up.sql new file mode 100644 index 000000000..e69de29bb diff --git a/internal/storage/migrations/V0009.FixConstraints.postgres.up.sql b/internal/storage/migrations/V0009.FixConstraints.postgres.up.sql new file mode 100644 index 000000000..6ea9834d6 --- /dev/null +++ b/internal/storage/migrations/V0009.FixConstraints.postgres.up.sql @@ -0,0 +1,6 @@ +ALTER TABLE webauthn_devices + ALTER COLUMN aaguid DROP NOT NULL; + +UPDATE webauthn_devices +SET aaguid = NULL +WHERE aaguid = '' OR aaguid = '00000000-00000000-00000000-00000000'; diff --git a/internal/storage/migrations_test.go b/internal/storage/migrations_test.go index 3eeb155ea..ac898aa46 100644 --- a/internal/storage/migrations_test.go +++ b/internal/storage/migrations_test.go @@ -9,7 +9,7 @@ import ( const ( // This is the latest schema version for the purpose of tests. - LatestVersion = 8 + LatestVersion = 9 ) func TestShouldObtainCorrectUpMigrations(t *testing.T) { @@ -44,6 +44,47 @@ func TestShouldObtainCorrectDownMigrations(t *testing.T) { } } +func TestMigrationShouldGetSpecificMigrationIfAvaliable(t *testing.T) { + upMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 8, 9) + require.NoError(t, err) + require.Len(t, upMigrationsPostgreSQL, 1) + + assert.True(t, upMigrationsPostgreSQL[0].Up) + assert.Equal(t, 9, upMigrationsPostgreSQL[0].Version) + assert.Equal(t, providerPostgres, upMigrationsPostgreSQL[0].Provider) + + upMigrationsSQLite, err := loadMigrations(providerSQLite, 8, 9) + require.NoError(t, err) + require.Len(t, upMigrationsSQLite, 1) + + assert.True(t, upMigrationsSQLite[0].Up) + assert.Equal(t, 9, upMigrationsSQLite[0].Version) + assert.Equal(t, providerAll, upMigrationsSQLite[0].Provider) + + downMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 9, 8) + require.NoError(t, err) + require.Len(t, downMigrationsPostgreSQL, 1) + + assert.False(t, downMigrationsPostgreSQL[0].Up) + assert.Equal(t, 9, downMigrationsPostgreSQL[0].Version) + assert.Equal(t, providerAll, downMigrationsPostgreSQL[0].Provider) + + downMigrationsSQLite, err := loadMigrations(providerSQLite, 9, 8) + require.NoError(t, err) + require.Len(t, downMigrationsSQLite, 1) + + assert.False(t, downMigrationsSQLite[0].Up) + assert.Equal(t, 9, downMigrationsSQLite[0].Version) + assert.Equal(t, providerAll, downMigrationsSQLite[0].Provider) +} + +func TestMigrationShouldReturnErrorOnSame(t *testing.T) { + migrations, err := loadMigrations(providerPostgres, 1, 1) + + assert.EqualError(t, err, "current version is same as migration target, no action being taken") + assert.Nil(t, migrations) +} + func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) { migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest) require.NoError(t, err) diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index 8c015e963..c940e6b34 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -244,14 +244,16 @@ func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection } func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) { - if _, err = conn.ExecContext(ctx, migration.Query); err != nil { - return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) - } + if migration.NotEmpty() { + if _, err = conn.ExecContext(ctx, migration.Query); err != nil { + return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) + } - if migration.Version == 1 && migration.Up { - // Add the schema encryption value if upgrading to v1. - if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil { - return err + if migration.Version == 1 && migration.Up { + // Add the schema encryption value if upgrading to v1. + if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil { + return err + } } }