fix(storage): postgresql webauthn tbl invalid aaguid constraint (#5183)

This fixes an issue with the PostgreSQL schema where the webauthn_devices table's aaguid column erroneously had a NOT NULL constraint.

Fixes #5182

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
James Elliott 2023-04-08 11:36:34 +10:00 committed by GitHub
parent 3b52ddb137
commit fa250ea7dd
8 changed files with 113 additions and 78 deletions
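For operators who want to confirm whether an existing PostgreSQL database is affected before or after upgrading, a query along these lines (illustrative only, not part of this change) reports the column's nullability; the table and column names match the schema touched by the diff below:

-- Illustrative check: 'NO' means the erroneous NOT NULL constraint is still present, 'YES' means it has been dropped.
SELECT is_nullable
FROM information_schema.columns
WHERE table_name = 'webauthn_devices'
  AND column_name = 'aaguid';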

View File

@@ -1,5 +1,9 @@
 package model

+import (
+    "strings"
+)
+
 // SchemaMigration represents an intended migration.
 type SchemaMigration struct {
     Version int
@@ -9,6 +13,11 @@ type SchemaMigration struct {
     Query string
 }

+// NotEmpty returns true if the SchemaMigration is not an empty string.
+func (m SchemaMigration) NotEmpty() bool {
+    return len(strings.TrimSpace(m.Query)) != 0
+}
+
 // Before returns the version the schema should be at Before the migration is applied.
 func (m SchemaMigration) Before() (before int) {
     if m.Up {
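The NotEmpty helper added above is used later in this change (in schemaMigrateApply) to skip executing a migration whose query is empty or all whitespace while still recording the version. A minimal standalone sketch of the same check, using a hypothetical local type rather than the real internal/model package:

package main

import (
    "fmt"
    "strings"
)

// schemaMigration mirrors only the Query field of model.SchemaMigration, for illustration.
type schemaMigration struct {
    Query string
}

// notEmpty mirrors the behaviour of the NotEmpty method introduced in this change.
func (m schemaMigration) notEmpty() bool {
    return len(strings.TrimSpace(m.Query)) != 0
}

func main() {
    fmt.Println(schemaMigration{Query: "ALTER TABLE webauthn_devices ALTER COLUMN aaguid DROP NOT NULL;"}.notEmpty()) // true
    fmt.Println(schemaMigration{Query: "\n\t "}.notEmpty())                                                           // false: treated as a no-op migration
}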

View File

@@ -4,6 +4,7 @@ import (
     "embed"
     "errors"
     "fmt"
+    "io/fs"
     "sort"
     "strconv"
     "strings"
@@ -15,8 +16,12 @@ import (
 var migrationsFS embed.FS

 func latestMigrationVersion(providerName string) (version int, err error) {
-    entries, err := migrationsFS.ReadDir("migrations")
-    if err != nil {
+    var (
+        entries   []fs.DirEntry
+        migration model.SchemaMigration
+    )
+
+    if entries, err = migrationsFS.ReadDir("migrations"); err != nil {
         return -1, err
     }
@@ -25,21 +30,20 @@ func latestMigrationVersion(providerName string) (version int, err error) {
             continue
         }

-        m, err := scanMigration(entry.Name())
-        if err != nil {
+        if migration, err = scanMigration(entry.Name()); err != nil {
             return -1, err
         }

-        if m.Provider != providerName {
+        if migration.Provider != providerName && migration.Provider != providerAll {
             continue
         }

-        if !m.Up {
+        if !migration.Up {
             continue
         }

-        if m.Version > version {
-            version = m.Version
+        if migration.Version > version {
+            version = migration.Version
         }
     }
@@ -50,12 +54,17 @@ func latestMigrationVersion(providerName string) (version int, err error) {
 // target versions. If the target version is -1 this indicates the latest version. If the target version is 0
 // this indicates the database zero state.
 func loadMigrations(providerName string, prior, target int) (migrations []model.SchemaMigration, err error) {
-    if prior == target && (prior != -1 || target != -1) {
+    if prior == target {
         return nil, ErrMigrateCurrentVersionSameAsTarget
     }

-    entries, err := migrationsFS.ReadDir("migrations")
-    if err != nil {
+    var (
+        migrationsAll []model.SchemaMigration
+        migration     model.SchemaMigration
+        entries       []fs.DirEntry
+    )
+
+    if entries, err = migrationsFS.ReadDir("migrations"); err != nil {
         return nil, err
     }
@@ -66,8 +75,7 @@ func loadMigrations(providerName string, prior, target int) (migrations []model.
             continue
         }

-        migration, err := scanMigration(entry.Name())
-        if err != nil {
+        if migration, err = scanMigration(entry.Name()); err != nil {
             return nil, err
         }
@@ -75,8 +83,29 @@ func loadMigrations(providerName string, prior, target int) (migrations []model.
             continue
         }

-        migrations = append(migrations, migration)
+        if migration.Provider == providerAll {
+            migrationsAll = append(migrationsAll, migration)
+        } else {
+            migrations = append(migrations, migration)
+        }
     }

+    // Add "all" migrations for versions that don't exist.
+    for _, am := range migrationsAll {
+        found := false
+
+        for _, m := range migrations {
+            if m.Version == am.Version {
+                found = true
+
+                break
+            }
+        }
+
+        if !found {
+            migrations = append(migrations, am)
+        }
+    }
+
     if up {
         sort.Slice(migrations, func(i, j int) bool {
@@ -103,7 +132,7 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
         return true
     }

-    if target != -1 && (migration.Version > target || migration.Version <= prior) {
+    if migration.Version > target || migration.Version <= prior {
         // Skip if the migration version is greater than the target or less than or equal to the previous version.
         return true
     }
@@ -113,12 +142,6 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
         return true
     }

-    if migration.Version == 1 && target == -1 {
-        // Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data
-        // preventing a successful migration.
-        return true
-    }
-
     if migration.Version <= target || migration.Version > prior {
         // Skip the migration if we want to go down and the migration version is less than or equal to the target
         // or greater than the previous version.
@@ -141,8 +164,9 @@ func scanMigration(m string) (migration model.SchemaMigration, err error) {
         Provider: result[reMigration.SubexpIndex("Provider")],
     }

-    data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
-    if err != nil {
+    var data []byte
+
+    if data, err = migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)); err != nil {
         return model.SchemaMigration{}, err
     }

View File

@@ -56,30 +56,8 @@ ALTER TABLE totp_configurations
 DROP INDEX IF EXISTS totp_configurations_username_key1;
 DROP INDEX IF EXISTS totp_configurations_username_key;

-ALTER TABLE totp_configurations
-    RENAME TO _bkp_UP_V0007_totp_configurations;
-
-CREATE TABLE IF NOT EXISTS totp_configurations (
-    id SERIAL CONSTRAINT totp_configurations_pkey PRIMARY KEY,
-    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
-    username VARCHAR(100) NOT NULL,
-    issuer VARCHAR(100),
-    algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
-    digits INTEGER NOT NULL DEFAULT 6,
-    period INTEGER NOT NULL DEFAULT 30,
-    secret BYTEA NOT NULL
-);
-
 CREATE UNIQUE INDEX totp_configurations_username_key ON totp_configurations (username);

-INSERT INTO totp_configurations (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
-SELECT created_at, last_used_at, username, issuer, algorithm, digits, period, secret
-FROM _bkp_UP_V0007_totp_configurations
-ORDER BY id;
-
-DROP TABLE IF EXISTS _bkp_UP_V0007_totp_configurations;
-
 ALTER TABLE webauthn_devices
     DROP CONSTRAINT IF EXISTS webauthn_devices_username_description_key1,
     DROP CONSTRAINT IF EXISTS webauthn_devices_kid_key1,
@@ -97,34 +75,9 @@ DROP INDEX IF EXISTS webauthn_devices_username_description_key;
 DROP INDEX IF EXISTS webauthn_devices_kid_key;
 DROP INDEX IF EXISTS webauthn_devices_lookup_key;

-ALTER TABLE webauthn_devices
-    RENAME TO _bkp_UP_V0007_webauthn_devices;
-
-CREATE TABLE IF NOT EXISTS webauthn_devices (
-    id SERIAL CONSTRAINT webauthn_devices_pkey PRIMARY KEY,
-    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
-    rpid TEXT,
-    username VARCHAR(100) NOT NULL,
-    description VARCHAR(30) NOT NULL DEFAULT 'Primary',
-    kid VARCHAR(512) NOT NULL,
-    public_key BYTEA NOT NULL,
-    attestation_type VARCHAR(32),
-    transport VARCHAR(20) DEFAULT '',
-    aaguid CHAR(36) NOT NULL,
-    sign_count INTEGER DEFAULT 0,
-    clone_warning BOOLEAN NOT NULL DEFAULT FALSE
-);
-
 CREATE UNIQUE INDEX webauthn_devices_kid_key ON webauthn_devices (kid);
 CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (username, description);

-INSERT INTO webauthn_devices (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
-SELECT created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning
-FROM _bkp_UP_V0007_webauthn_devices;
-
-DROP TABLE IF EXISTS _bkp_UP_V0007_webauthn_devices;
-
 ALTER TABLE oauth2_consent_session
     DROP CONSTRAINT oauth2_consent_session_subject_fkey,
     DROP CONSTRAINT oauth2_consent_session_preconfiguration_fkey;

View File

@@ -0,0 +1,6 @@
+ALTER TABLE webauthn_devices
+    ALTER COLUMN aaguid DROP NOT NULL;
+
+UPDATE webauthn_devices
+SET aaguid = NULL
+WHERE aaguid = '' OR aaguid = '00000000-00000000-00000000-00000000';

View File

@@ -9,7 +9,7 @@ import (
 const (
     // This is the latest schema version for the purpose of tests.
-    LatestVersion = 8
+    LatestVersion = 9
 )

 func TestShouldObtainCorrectUpMigrations(t *testing.T) {
@@ -44,6 +44,47 @@ func TestShouldObtainCorrectDownMigrations(t *testing.T) {
     }
 }

+func TestMigrationShouldGetSpecificMigrationIfAvaliable(t *testing.T) {
+    upMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 8, 9)
+    require.NoError(t, err)
+
+    require.Len(t, upMigrationsPostgreSQL, 1)
+    assert.True(t, upMigrationsPostgreSQL[0].Up)
+    assert.Equal(t, 9, upMigrationsPostgreSQL[0].Version)
+    assert.Equal(t, providerPostgres, upMigrationsPostgreSQL[0].Provider)
+
+    upMigrationsSQLite, err := loadMigrations(providerSQLite, 8, 9)
+    require.NoError(t, err)
+
+    require.Len(t, upMigrationsSQLite, 1)
+    assert.True(t, upMigrationsSQLite[0].Up)
+    assert.Equal(t, 9, upMigrationsSQLite[0].Version)
+    assert.Equal(t, providerAll, upMigrationsSQLite[0].Provider)
+
+    downMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 9, 8)
+    require.NoError(t, err)
+
+    require.Len(t, downMigrationsPostgreSQL, 1)
+    assert.False(t, downMigrationsPostgreSQL[0].Up)
+    assert.Equal(t, 9, downMigrationsPostgreSQL[0].Version)
+    assert.Equal(t, providerAll, downMigrationsPostgreSQL[0].Provider)
+
+    downMigrationsSQLite, err := loadMigrations(providerSQLite, 9, 8)
+    require.NoError(t, err)
+
+    require.Len(t, downMigrationsSQLite, 1)
+    assert.False(t, downMigrationsSQLite[0].Up)
+    assert.Equal(t, 9, downMigrationsSQLite[0].Version)
+    assert.Equal(t, providerAll, downMigrationsSQLite[0].Provider)
+}
+
+func TestMigrationShouldReturnErrorOnSame(t *testing.T) {
+    migrations, err := loadMigrations(providerPostgres, 1, 1)
+
+    assert.EqualError(t, err, "current version is same as migration target, no action being taken")
+    assert.Nil(t, migrations)
+}
+
 func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) {
     migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest)
     require.NoError(t, err)

View File

@@ -244,6 +244,7 @@ func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection
 }

 func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
+    if migration.NotEmpty() {
     if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
         return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
     }
@@ -254,6 +255,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio
             return err
         }
     }
+    }

     if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil {
         return err