From dfbbf1a1f3a70bfdbed854549a0f8806c3e368bc Mon Sep 17 00:00:00 2001 From: James Elliott Date: Tue, 11 Apr 2023 21:11:11 +1000 Subject: [PATCH] fix(model): yaml encoding of totp and webauthn fails (#5204) This fixes an issue where the encoding of the YAML files fails when exporting TOTP/WebAuthn devices. Signed-off-by: James Elliott --- internal/model/totp_configuration.go | 39 ++++++++-- internal/model/totp_configuration_test.go | 62 ++++++++++++++++ internal/model/types_test.go | 10 --- internal/model/webauthn.go | 39 ++++++++-- internal/model/webauthn_test.go | 88 +++++++++++++++++++++++ 5 files changed, 218 insertions(+), 20 deletions(-) create mode 100644 internal/model/webauthn_test.go diff --git a/internal/model/totp_configuration.go b/internal/model/totp_configuration.go index a262ddb7b..0d325ce0a 100644 --- a/internal/model/totp_configuration.go +++ b/internal/model/totp_configuration.go @@ -25,9 +25,12 @@ type TOTPConfiguration struct { Secret []byte `db:"secret" json:"-"` } +// LastUsed provides LastUsedAt as a *time.Time instead of sql.NullTime. func (c *TOTPConfiguration) LastUsed() *time.Time { if c.LastUsedAt.Valid { - return &c.LastUsedAt.Time + value := time.Unix(c.LastUsedAt.Time.Unix(), int64(c.LastUsedAt.Time.Nanosecond())) + + return &value } return nil @@ -73,9 +76,9 @@ func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error return key.Image(width, height) } -// MarshalYAML marshals this model into YAML. -func (c *TOTPConfiguration) MarshalYAML() (any, error) { - o := TOTPConfigurationData{ +// ToData converts this TOTPConfiguration into the data format for exporting etc. +func (c *TOTPConfiguration) ToData() TOTPConfigurationData { + return TOTPConfigurationData{ CreatedAt: c.CreatedAt, LastUsedAt: c.LastUsed(), Username: c.Username, @@ -85,8 +88,11 @@ func (c *TOTPConfiguration) MarshalYAML() (any, error) { Period: c.Period, Secret: base64.StdEncoding.EncodeToString(c.Secret), } +} - return yaml.Marshal(o) +// MarshalYAML marshals this model into YAML. +func (c *TOTPConfiguration) MarshalYAML() (any, error) { + return c.ToData(), nil } // UnmarshalYAML unmarshalls YAML into this model. @@ -127,7 +133,30 @@ type TOTPConfigurationData struct { Secret string `yaml:"secret"` } +// TOTPConfigurationDataExport represents a TOTPConfiguration export file. +type TOTPConfigurationDataExport struct { + TOTPConfigurations []TOTPConfigurationData `yaml:"totp_configurations"` +} + // TOTPConfigurationExport represents a TOTPConfiguration export file. type TOTPConfigurationExport struct { TOTPConfigurations []TOTPConfiguration `yaml:"totp_configurations"` } + +// ToData converts this TOTPConfigurationExport into a TOTPConfigurationDataExport. +func (export TOTPConfigurationExport) ToData() TOTPConfigurationDataExport { + data := TOTPConfigurationDataExport{ + TOTPConfigurations: make([]TOTPConfigurationData, len(export.TOTPConfigurations)), + } + + for i, config := range export.TOTPConfigurations { + data.TOTPConfigurations[i] = config.ToData() + } + + return data +} + +// MarshalYAML marshals this model into YAML. +func (export TOTPConfigurationExport) MarshalYAML() (any, error) { + return export.ToData(), nil +} diff --git a/internal/model/totp_configuration_test.go b/internal/model/totp_configuration_test.go index 81d82aed3..07c38a1f7 100644 --- a/internal/model/totp_configuration_test.go +++ b/internal/model/totp_configuration_test.go @@ -1,11 +1,14 @@ package model import ( + "database/sql" "encoding/json" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) /* @@ -75,3 +78,62 @@ func TestShouldReturnImage(t *testing.T) { assert.Equal(t, 41, img.Bounds().Dx()) assert.Equal(t, 41, img.Bounds().Dy()) } + +func TestTOTPConfigurationImportExport(t *testing.T) { + have := TOTPConfigurationExport{ + TOTPConfigurations: []TOTPConfiguration{ + { + ID: 0, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Valid: false}, + Username: "john", + Issuer: "example", + Algorithm: "SHA1", + Digits: 6, + Period: 30, + Secret: MustRead(80), + }, + { + ID: 1, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Time: time.Now(), Valid: true}, + Username: "abc", + Issuer: "example2", + Algorithm: "SHA512", + Digits: 8, + Period: 90, + Secret: MustRead(120), + }, + }, + } + + out, err := yaml.Marshal(&have) + require.NoError(t, err) + + imported := TOTPConfigurationExport{} + + require.NoError(t, yaml.Unmarshal(out, &imported)) + + require.Equal(t, len(have.TOTPConfigurations), len(imported.TOTPConfigurations)) + + for i, actual := range imported.TOTPConfigurations { + t.Run(actual.Username, func(t *testing.T) { + expected := have.TOTPConfigurations[i] + + if expected.ID != 0 { + assert.NotEqual(t, expected.ID, actual.ID) + } else { + assert.Equal(t, expected.ID, actual.ID) + } + + assert.Equal(t, expected.Username, actual.Username) + assert.Equal(t, expected.Issuer, actual.Issuer) + assert.Equal(t, expected.Algorithm, actual.Algorithm) + assert.Equal(t, expected.Digits, actual.Digits) + assert.Equal(t, expected.Period, actual.Period) + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + assert.WithinDuration(t, expected.LastUsedAt.Time, actual.LastUsedAt.Time, time.Second) + assert.Equal(t, expected.LastUsedAt.Valid, actual.LastUsedAt.Valid) + }) + } +} diff --git a/internal/model/types_test.go b/internal/model/types_test.go index 50d8bf38c..ada1e9bbb 100644 --- a/internal/model/types_test.go +++ b/internal/model/types_test.go @@ -1,21 +1,11 @@ package model import ( - "fmt" "testing" - "github.com/ory/fosite" "github.com/stretchr/testify/assert" ) -func Test(t *testing.T) { - args := fosite.Arguments{"abc", "123"} - - x := StringSlicePipeDelimited(args) - - fmt.Println(x) -} - func TestDatabaseModelTypeIP(t *testing.T) { ip := IP{} diff --git a/internal/model/webauthn.go b/internal/model/webauthn.go index bf53b9681..0921ccd24 100644 --- a/internal/model/webauthn.go +++ b/internal/model/webauthn.go @@ -172,17 +172,20 @@ func (d *WebAuthnDevice) UpdateSignInInfo(config *webauthn.Config, now time.Time } } +// LastUsed provides LastUsedAt as a *time.Time instead of sql.NullTime. func (d *WebAuthnDevice) LastUsed() *time.Time { if d.LastUsedAt.Valid { - return &d.LastUsedAt.Time + value := time.Unix(d.LastUsedAt.Time.Unix(), int64(d.LastUsedAt.Time.Nanosecond())) + + return &value } return nil } -// MarshalYAML marshals this model into YAML. -func (d *WebAuthnDevice) MarshalYAML() (any, error) { - o := WebAuthnDeviceData{ +// ToData converts this WebAuthnDevice into the data format for exporting etc. +func (d *WebAuthnDevice) ToData() WebAuthnDeviceData { + return WebAuthnDeviceData{ CreatedAt: d.CreatedAt, LastUsedAt: d.LastUsed(), RPID: d.RPID, @@ -196,8 +199,11 @@ func (d *WebAuthnDevice) MarshalYAML() (any, error) { SignCount: d.SignCount, CloneWarning: d.CloneWarning, } +} - return yaml.Marshal(o) +// MarshalYAML marshals this model into YAML. +func (d *WebAuthnDevice) MarshalYAML() (any, error) { + return d.ToData(), nil } // UnmarshalYAML unmarshalls YAML into this model. @@ -266,3 +272,26 @@ type WebAuthnDeviceData struct { type WebAuthnDeviceExport struct { WebAuthnDevices []WebAuthnDevice `yaml:"webauthn_devices"` } + +// WebAuthnDeviceDataExport represents a WebAuthnDevice export file. +type WebAuthnDeviceDataExport struct { + WebAuthnDevices []WebAuthnDeviceData `yaml:"webauthn_devices"` +} + +// ToData converts this WebAuthnDeviceExport into a WebAuthnDeviceDataExport. +func (export WebAuthnDeviceExport) ToData() WebAuthnDeviceDataExport { + data := WebAuthnDeviceDataExport{ + WebAuthnDevices: make([]WebAuthnDeviceData, len(export.WebAuthnDevices)), + } + + for i, device := range export.WebAuthnDevices { + data.WebAuthnDevices[i] = device.ToData() + } + + return data +} + +// MarshalYAML marshals this model into YAML. +func (export WebAuthnDeviceExport) MarshalYAML() (any, error) { + return export.ToData(), nil +} diff --git a/internal/model/webauthn_test.go b/internal/model/webauthn_test.go new file mode 100644 index 000000000..654c1a178 --- /dev/null +++ b/internal/model/webauthn_test.go @@ -0,0 +1,88 @@ +package model + +import ( + "crypto/rand" + "database/sql" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestWebAuthnDeviceImportExport(t *testing.T) { + have := WebAuthnDeviceExport{ + WebAuthnDevices: []WebAuthnDevice{ + { + ID: 0, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Time: time.Now(), Valid: true}, + RPID: "example", + Username: "john", + Description: "akey", + KID: NewBase64(MustRead(20)), + PublicKey: MustRead(128), + AttestationType: "fido-u2f", + Transport: "", + AAGUID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + SignCount: 20, + CloneWarning: false, + }, + { + ID: 0, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Valid: false}, + RPID: "example2", + Username: "john2", + Description: "bkey", + KID: NewBase64(MustRead(60)), + PublicKey: MustRead(64), + AttestationType: "packed", + Transport: "", + AAGUID: uuid.NullUUID{Valid: false}, + SignCount: 30, + CloneWarning: true, + }, + }, + } + + out, err := yaml.Marshal(&have) + require.NoError(t, err) + + imported := WebAuthnDeviceExport{} + + require.NoError(t, yaml.Unmarshal(out, &imported)) + require.Equal(t, len(have.WebAuthnDevices), len(imported.WebAuthnDevices)) + + for i, actual := range imported.WebAuthnDevices { + t.Run(actual.Description, func(t *testing.T) { + expected := have.WebAuthnDevices[i] + + assert.Equal(t, expected.KID, actual.KID) + assert.Equal(t, expected.PublicKey, actual.PublicKey) + assert.Equal(t, expected.SignCount, actual.SignCount) + assert.Equal(t, expected.AttestationType, actual.AttestationType) + assert.Equal(t, expected.RPID, actual.RPID) + assert.Equal(t, expected.AAGUID.Valid, actual.AAGUID.Valid) + assert.Equal(t, expected.AAGUID.UUID, actual.AAGUID.UUID) + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + assert.WithinDuration(t, expected.LastUsedAt.Time, actual.LastUsedAt.Time, time.Second) + assert.Equal(t, expected.LastUsedAt.Valid, actual.LastUsedAt.Valid) + assert.Equal(t, expected.CloneWarning, actual.CloneWarning) + assert.Equal(t, expected.Description, actual.Description) + assert.Equal(t, expected.Username, actual.Username) + }) + } +} + +func MustRead(n int) []byte { + data := make([]byte, n) + + if _, err := rand.Read(data); err != nil { + panic(err) + } + + return data +}