fix(storage): postgresql webauthn tbl invalid aaguid constraint (#5183)
This fixes an issue with the PostgreSQL schema where the webauthn tables aaguid column had a NOT NULL constraint erroneously. Fixes #5182 Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>pull/5181/head
parent
3b52ddb137
commit
fa250ea7dd
|
@ -1,5 +1,9 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// SchemaMigration represents an intended migration.
|
// SchemaMigration represents an intended migration.
|
||||||
type SchemaMigration struct {
|
type SchemaMigration struct {
|
||||||
Version int
|
Version int
|
||||||
|
@ -9,6 +13,11 @@ type SchemaMigration struct {
|
||||||
Query string
|
Query string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NotEmpty returns true if the SchemaMigration is not an empty string.
|
||||||
|
func (m SchemaMigration) NotEmpty() bool {
|
||||||
|
return len(strings.TrimSpace(m.Query)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
// Before returns the version the schema should be at Before the migration is applied.
|
// Before returns the version the schema should be at Before the migration is applied.
|
||||||
func (m SchemaMigration) Before() (before int) {
|
func (m SchemaMigration) Before() (before int) {
|
||||||
if m.Up {
|
if m.Up {
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"embed"
|
"embed"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -15,8 +16,12 @@ import (
|
||||||
var migrationsFS embed.FS
|
var migrationsFS embed.FS
|
||||||
|
|
||||||
func latestMigrationVersion(providerName string) (version int, err error) {
|
func latestMigrationVersion(providerName string) (version int, err error) {
|
||||||
entries, err := migrationsFS.ReadDir("migrations")
|
var (
|
||||||
if err != nil {
|
entries []fs.DirEntry
|
||||||
|
migration model.SchemaMigration
|
||||||
|
)
|
||||||
|
|
||||||
|
if entries, err = migrationsFS.ReadDir("migrations"); err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,21 +30,20 @@ func latestMigrationVersion(providerName string) (version int, err error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := scanMigration(entry.Name())
|
if migration, err = scanMigration(entry.Name()); err != nil {
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Provider != providerName {
|
if migration.Provider != providerName && migration.Provider != providerAll {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.Up {
|
if !migration.Up {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Version > version {
|
if migration.Version > version {
|
||||||
version = m.Version
|
version = migration.Version
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,12 +54,17 @@ func latestMigrationVersion(providerName string) (version int, err error) {
|
||||||
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
|
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
|
||||||
// this indicates the database zero state.
|
// this indicates the database zero state.
|
||||||
func loadMigrations(providerName string, prior, target int) (migrations []model.SchemaMigration, err error) {
|
func loadMigrations(providerName string, prior, target int) (migrations []model.SchemaMigration, err error) {
|
||||||
if prior == target && (prior != -1 || target != -1) {
|
if prior == target {
|
||||||
return nil, ErrMigrateCurrentVersionSameAsTarget
|
return nil, ErrMigrateCurrentVersionSameAsTarget
|
||||||
}
|
}
|
||||||
|
|
||||||
entries, err := migrationsFS.ReadDir("migrations")
|
var (
|
||||||
if err != nil {
|
migrationsAll []model.SchemaMigration
|
||||||
|
migration model.SchemaMigration
|
||||||
|
entries []fs.DirEntry
|
||||||
|
)
|
||||||
|
|
||||||
|
if entries, err = migrationsFS.ReadDir("migrations"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,8 +75,7 @@ func loadMigrations(providerName string, prior, target int) (migrations []model.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
migration, err := scanMigration(entry.Name())
|
if migration, err = scanMigration(entry.Name()); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,8 +83,29 @@ func loadMigrations(providerName string, prior, target int) (migrations []model.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if migration.Provider == providerAll {
|
||||||
|
migrationsAll = append(migrationsAll, migration)
|
||||||
|
} else {
|
||||||
migrations = append(migrations, migration)
|
migrations = append(migrations, migration)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add "all" migrations for versions that don't exist.
|
||||||
|
for _, am := range migrationsAll {
|
||||||
|
found := false
|
||||||
|
|
||||||
|
for _, m := range migrations {
|
||||||
|
if m.Version == am.Version {
|
||||||
|
found = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
migrations = append(migrations, am)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if up {
|
if up {
|
||||||
sort.Slice(migrations, func(i, j int) bool {
|
sort.Slice(migrations, func(i, j int) bool {
|
||||||
|
@ -103,7 +132,7 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if target != -1 && (migration.Version > target || migration.Version <= prior) {
|
if migration.Version > target || migration.Version <= prior {
|
||||||
// Skip if the migration version is greater than the target or less than or equal to the previous version.
|
// Skip if the migration version is greater than the target or less than or equal to the previous version.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -113,12 +142,6 @@ func skipMigration(providerName string, up bool, target, prior int, migration *m
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if migration.Version == 1 && target == -1 {
|
|
||||||
// Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data
|
|
||||||
// preventing a successful migration.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if migration.Version <= target || migration.Version > prior {
|
if migration.Version <= target || migration.Version > prior {
|
||||||
// Skip the migration if we want to go down and the migration version is less than or equal to the target
|
// Skip the migration if we want to go down and the migration version is less than or equal to the target
|
||||||
// or greater than the previous version.
|
// or greater than the previous version.
|
||||||
|
@ -141,8 +164,9 @@ func scanMigration(m string) (migration model.SchemaMigration, err error) {
|
||||||
Provider: result[reMigration.SubexpIndex("Provider")],
|
Provider: result[reMigration.SubexpIndex("Provider")],
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
|
var data []byte
|
||||||
if err != nil {
|
|
||||||
|
if data, err = migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)); err != nil {
|
||||||
return model.SchemaMigration{}, err
|
return model.SchemaMigration{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,30 +56,8 @@ ALTER TABLE totp_configurations
|
||||||
DROP INDEX IF EXISTS totp_configurations_username_key1;
|
DROP INDEX IF EXISTS totp_configurations_username_key1;
|
||||||
DROP INDEX IF EXISTS totp_configurations_username_key;
|
DROP INDEX IF EXISTS totp_configurations_username_key;
|
||||||
|
|
||||||
ALTER TABLE totp_configurations
|
|
||||||
RENAME TO _bkp_UP_V0007_totp_configurations;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS totp_configurations (
|
|
||||||
id SERIAL CONSTRAINT totp_configurations_pkey PRIMARY KEY,
|
|
||||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
|
||||||
username VARCHAR(100) NOT NULL,
|
|
||||||
issuer VARCHAR(100),
|
|
||||||
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
|
|
||||||
digits INTEGER NOT NULL DEFAULT 6,
|
|
||||||
period INTEGER NOT NULL DEFAULT 30,
|
|
||||||
secret BYTEA NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE UNIQUE INDEX totp_configurations_username_key ON totp_configurations (username);
|
CREATE UNIQUE INDEX totp_configurations_username_key ON totp_configurations (username);
|
||||||
|
|
||||||
INSERT INTO totp_configurations (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
|
|
||||||
SELECT created_at, last_used_at, username, issuer, algorithm, digits, period, secret
|
|
||||||
FROM _bkp_UP_V0007_totp_configurations
|
|
||||||
ORDER BY id;
|
|
||||||
|
|
||||||
DROP TABLE IF EXISTS _bkp_UP_V0007_totp_configurations;
|
|
||||||
|
|
||||||
ALTER TABLE webauthn_devices
|
ALTER TABLE webauthn_devices
|
||||||
DROP CONSTRAINT IF EXISTS webauthn_devices_username_description_key1,
|
DROP CONSTRAINT IF EXISTS webauthn_devices_username_description_key1,
|
||||||
DROP CONSTRAINT IF EXISTS webauthn_devices_kid_key1,
|
DROP CONSTRAINT IF EXISTS webauthn_devices_kid_key1,
|
||||||
|
@ -97,34 +75,9 @@ DROP INDEX IF EXISTS webauthn_devices_username_description_key;
|
||||||
DROP INDEX IF EXISTS webauthn_devices_kid_key;
|
DROP INDEX IF EXISTS webauthn_devices_kid_key;
|
||||||
DROP INDEX IF EXISTS webauthn_devices_lookup_key;
|
DROP INDEX IF EXISTS webauthn_devices_lookup_key;
|
||||||
|
|
||||||
ALTER TABLE webauthn_devices
|
|
||||||
RENAME TO _bkp_UP_V0007_webauthn_devices;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS webauthn_devices (
|
|
||||||
id SERIAL CONSTRAINT webauthn_devices_pkey PRIMARY KEY,
|
|
||||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
last_used_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
|
||||||
rpid TEXT,
|
|
||||||
username VARCHAR(100) NOT NULL,
|
|
||||||
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
|
|
||||||
kid VARCHAR(512) NOT NULL,
|
|
||||||
public_key BYTEA NOT NULL,
|
|
||||||
attestation_type VARCHAR(32),
|
|
||||||
transport VARCHAR(20) DEFAULT '',
|
|
||||||
aaguid CHAR(36) NOT NULL,
|
|
||||||
sign_count INTEGER DEFAULT 0,
|
|
||||||
clone_warning BOOLEAN NOT NULL DEFAULT FALSE
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE UNIQUE INDEX webauthn_devices_kid_key ON webauthn_devices (kid);
|
CREATE UNIQUE INDEX webauthn_devices_kid_key ON webauthn_devices (kid);
|
||||||
CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (username, description);
|
CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (username, description);
|
||||||
|
|
||||||
INSERT INTO webauthn_devices (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
|
|
||||||
SELECT created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning
|
|
||||||
FROM _bkp_UP_V0007_webauthn_devices;
|
|
||||||
|
|
||||||
DROP TABLE IF EXISTS _bkp_UP_V0007_webauthn_devices;
|
|
||||||
|
|
||||||
ALTER TABLE oauth2_consent_session
|
ALTER TABLE oauth2_consent_session
|
||||||
DROP CONSTRAINT oauth2_consent_session_subject_fkey,
|
DROP CONSTRAINT oauth2_consent_session_subject_fkey,
|
||||||
DROP CONSTRAINT oauth2_consent_session_preconfiguration_fkey;
|
DROP CONSTRAINT oauth2_consent_session_preconfiguration_fkey;
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
ALTER TABLE webauthn_devices
|
||||||
|
ALTER COLUMN aaguid DROP NOT NULL;
|
||||||
|
|
||||||
|
UPDATE webauthn_devices
|
||||||
|
SET aaguid = NULL
|
||||||
|
WHERE aaguid = '' OR aaguid = '00000000-00000000-00000000-00000000';
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// This is the latest schema version for the purpose of tests.
|
// This is the latest schema version for the purpose of tests.
|
||||||
LatestVersion = 8
|
LatestVersion = 9
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
||||||
|
@ -44,6 +44,47 @@ func TestShouldObtainCorrectDownMigrations(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrationShouldGetSpecificMigrationIfAvaliable(t *testing.T) {
|
||||||
|
upMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 8, 9)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, upMigrationsPostgreSQL, 1)
|
||||||
|
|
||||||
|
assert.True(t, upMigrationsPostgreSQL[0].Up)
|
||||||
|
assert.Equal(t, 9, upMigrationsPostgreSQL[0].Version)
|
||||||
|
assert.Equal(t, providerPostgres, upMigrationsPostgreSQL[0].Provider)
|
||||||
|
|
||||||
|
upMigrationsSQLite, err := loadMigrations(providerSQLite, 8, 9)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, upMigrationsSQLite, 1)
|
||||||
|
|
||||||
|
assert.True(t, upMigrationsSQLite[0].Up)
|
||||||
|
assert.Equal(t, 9, upMigrationsSQLite[0].Version)
|
||||||
|
assert.Equal(t, providerAll, upMigrationsSQLite[0].Provider)
|
||||||
|
|
||||||
|
downMigrationsPostgreSQL, err := loadMigrations(providerPostgres, 9, 8)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, downMigrationsPostgreSQL, 1)
|
||||||
|
|
||||||
|
assert.False(t, downMigrationsPostgreSQL[0].Up)
|
||||||
|
assert.Equal(t, 9, downMigrationsPostgreSQL[0].Version)
|
||||||
|
assert.Equal(t, providerAll, downMigrationsPostgreSQL[0].Provider)
|
||||||
|
|
||||||
|
downMigrationsSQLite, err := loadMigrations(providerSQLite, 9, 8)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, downMigrationsSQLite, 1)
|
||||||
|
|
||||||
|
assert.False(t, downMigrationsSQLite[0].Up)
|
||||||
|
assert.Equal(t, 9, downMigrationsSQLite[0].Version)
|
||||||
|
assert.Equal(t, providerAll, downMigrationsSQLite[0].Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrationShouldReturnErrorOnSame(t *testing.T) {
|
||||||
|
migrations, err := loadMigrations(providerPostgres, 1, 1)
|
||||||
|
|
||||||
|
assert.EqualError(t, err, "current version is same as migration target, no action being taken")
|
||||||
|
assert.Nil(t, migrations)
|
||||||
|
}
|
||||||
|
|
||||||
func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) {
|
func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) {
|
||||||
migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest)
|
migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -244,6 +244,7 @@ func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
|
func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
|
||||||
|
if migration.NotEmpty() {
|
||||||
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
|
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
|
||||||
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
|
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
|
||||||
}
|
}
|
||||||
|
@ -254,6 +255,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnectio
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil {
|
if err = p.schemaMigrateFinalize(ctx, conn, migration); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in New Issue