fix(model): yaml encoding of totp and webauthn fails (#5204)

This fixes an issue where the encoding of the YAML files fails when exporting TOTP/WebAuthn devices.

Signed-off-by: James Elliott <james-d-elliott@users.noreply.github.com>
pull/5129/head^2
James Elliott 2023-04-11 21:11:11 +10:00 committed by GitHub
parent 569af0fef0
commit dfbbf1a1f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 218 additions and 20 deletions

View File

@ -25,9 +25,12 @@ type TOTPConfiguration struct {
Secret []byte `db:"secret" json:"-"` Secret []byte `db:"secret" json:"-"`
} }
// LastUsed provides LastUsedAt as a *time.Time instead of sql.NullTime.
func (c *TOTPConfiguration) LastUsed() *time.Time { func (c *TOTPConfiguration) LastUsed() *time.Time {
if c.LastUsedAt.Valid { if c.LastUsedAt.Valid {
return &c.LastUsedAt.Time value := time.Unix(c.LastUsedAt.Time.Unix(), int64(c.LastUsedAt.Time.Nanosecond()))
return &value
} }
return nil return nil
@ -73,9 +76,9 @@ func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error
return key.Image(width, height) return key.Image(width, height)
} }
// MarshalYAML marshals this model into YAML. // ToData converts this TOTPConfiguration into the data format for exporting etc.
func (c *TOTPConfiguration) MarshalYAML() (any, error) { func (c *TOTPConfiguration) ToData() TOTPConfigurationData {
o := TOTPConfigurationData{ return TOTPConfigurationData{
CreatedAt: c.CreatedAt, CreatedAt: c.CreatedAt,
LastUsedAt: c.LastUsed(), LastUsedAt: c.LastUsed(),
Username: c.Username, Username: c.Username,
@ -85,8 +88,11 @@ func (c *TOTPConfiguration) MarshalYAML() (any, error) {
Period: c.Period, Period: c.Period,
Secret: base64.StdEncoding.EncodeToString(c.Secret), Secret: base64.StdEncoding.EncodeToString(c.Secret),
} }
}
return yaml.Marshal(o) // MarshalYAML marshals this model into YAML.
func (c *TOTPConfiguration) MarshalYAML() (any, error) {
return c.ToData(), nil
} }
// UnmarshalYAML unmarshalls YAML into this model. // UnmarshalYAML unmarshalls YAML into this model.
@ -127,7 +133,30 @@ type TOTPConfigurationData struct {
Secret string `yaml:"secret"` Secret string `yaml:"secret"`
} }
// TOTPConfigurationDataExport represents a TOTPConfiguration export file.
type TOTPConfigurationDataExport struct {
TOTPConfigurations []TOTPConfigurationData `yaml:"totp_configurations"`
}
// TOTPConfigurationExport represents a TOTPConfiguration export file. // TOTPConfigurationExport represents a TOTPConfiguration export file.
type TOTPConfigurationExport struct { type TOTPConfigurationExport struct {
TOTPConfigurations []TOTPConfiguration `yaml:"totp_configurations"` TOTPConfigurations []TOTPConfiguration `yaml:"totp_configurations"`
} }
// ToData converts this TOTPConfigurationExport into a TOTPConfigurationDataExport.
func (export TOTPConfigurationExport) ToData() TOTPConfigurationDataExport {
data := TOTPConfigurationDataExport{
TOTPConfigurations: make([]TOTPConfigurationData, len(export.TOTPConfigurations)),
}
for i, config := range export.TOTPConfigurations {
data.TOTPConfigurations[i] = config.ToData()
}
return data
}
// MarshalYAML marshals this model into YAML.
func (export TOTPConfigurationExport) MarshalYAML() (any, error) {
return export.ToData(), nil
}

View File

@ -1,11 +1,14 @@
package model package model
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
) )
/* /*
@ -75,3 +78,62 @@ func TestShouldReturnImage(t *testing.T) {
assert.Equal(t, 41, img.Bounds().Dx()) assert.Equal(t, 41, img.Bounds().Dx())
assert.Equal(t, 41, img.Bounds().Dy()) assert.Equal(t, 41, img.Bounds().Dy())
} }
func TestTOTPConfigurationImportExport(t *testing.T) {
have := TOTPConfigurationExport{
TOTPConfigurations: []TOTPConfiguration{
{
ID: 0,
CreatedAt: time.Now(),
LastUsedAt: sql.NullTime{Valid: false},
Username: "john",
Issuer: "example",
Algorithm: "SHA1",
Digits: 6,
Period: 30,
Secret: MustRead(80),
},
{
ID: 1,
CreatedAt: time.Now(),
LastUsedAt: sql.NullTime{Time: time.Now(), Valid: true},
Username: "abc",
Issuer: "example2",
Algorithm: "SHA512",
Digits: 8,
Period: 90,
Secret: MustRead(120),
},
},
}
out, err := yaml.Marshal(&have)
require.NoError(t, err)
imported := TOTPConfigurationExport{}
require.NoError(t, yaml.Unmarshal(out, &imported))
require.Equal(t, len(have.TOTPConfigurations), len(imported.TOTPConfigurations))
for i, actual := range imported.TOTPConfigurations {
t.Run(actual.Username, func(t *testing.T) {
expected := have.TOTPConfigurations[i]
if expected.ID != 0 {
assert.NotEqual(t, expected.ID, actual.ID)
} else {
assert.Equal(t, expected.ID, actual.ID)
}
assert.Equal(t, expected.Username, actual.Username)
assert.Equal(t, expected.Issuer, actual.Issuer)
assert.Equal(t, expected.Algorithm, actual.Algorithm)
assert.Equal(t, expected.Digits, actual.Digits)
assert.Equal(t, expected.Period, actual.Period)
assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second)
assert.WithinDuration(t, expected.LastUsedAt.Time, actual.LastUsedAt.Time, time.Second)
assert.Equal(t, expected.LastUsedAt.Valid, actual.LastUsedAt.Valid)
})
}
}

View File

@ -1,21 +1,11 @@
package model package model
import ( import (
"fmt"
"testing" "testing"
"github.com/ory/fosite"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test(t *testing.T) {
args := fosite.Arguments{"abc", "123"}
x := StringSlicePipeDelimited(args)
fmt.Println(x)
}
func TestDatabaseModelTypeIP(t *testing.T) { func TestDatabaseModelTypeIP(t *testing.T) {
ip := IP{} ip := IP{}

View File

@ -172,17 +172,20 @@ func (d *WebAuthnDevice) UpdateSignInInfo(config *webauthn.Config, now time.Time
} }
} }
// LastUsed provides LastUsedAt as a *time.Time instead of sql.NullTime.
func (d *WebAuthnDevice) LastUsed() *time.Time { func (d *WebAuthnDevice) LastUsed() *time.Time {
if d.LastUsedAt.Valid { if d.LastUsedAt.Valid {
return &d.LastUsedAt.Time value := time.Unix(d.LastUsedAt.Time.Unix(), int64(d.LastUsedAt.Time.Nanosecond()))
return &value
} }
return nil return nil
} }
// MarshalYAML marshals this model into YAML. // ToData converts this WebAuthnDevice into the data format for exporting etc.
func (d *WebAuthnDevice) MarshalYAML() (any, error) { func (d *WebAuthnDevice) ToData() WebAuthnDeviceData {
o := WebAuthnDeviceData{ return WebAuthnDeviceData{
CreatedAt: d.CreatedAt, CreatedAt: d.CreatedAt,
LastUsedAt: d.LastUsed(), LastUsedAt: d.LastUsed(),
RPID: d.RPID, RPID: d.RPID,
@ -196,8 +199,11 @@ func (d *WebAuthnDevice) MarshalYAML() (any, error) {
SignCount: d.SignCount, SignCount: d.SignCount,
CloneWarning: d.CloneWarning, CloneWarning: d.CloneWarning,
} }
}
return yaml.Marshal(o) // MarshalYAML marshals this model into YAML.
func (d *WebAuthnDevice) MarshalYAML() (any, error) {
return d.ToData(), nil
} }
// UnmarshalYAML unmarshalls YAML into this model. // UnmarshalYAML unmarshalls YAML into this model.
@ -266,3 +272,26 @@ type WebAuthnDeviceData struct {
type WebAuthnDeviceExport struct { type WebAuthnDeviceExport struct {
WebAuthnDevices []WebAuthnDevice `yaml:"webauthn_devices"` WebAuthnDevices []WebAuthnDevice `yaml:"webauthn_devices"`
} }
// WebAuthnDeviceDataExport represents a WebAuthnDevice export file.
type WebAuthnDeviceDataExport struct {
WebAuthnDevices []WebAuthnDeviceData `yaml:"webauthn_devices"`
}
// ToData converts this WebAuthnDeviceExport into a WebAuthnDeviceDataExport.
func (export WebAuthnDeviceExport) ToData() WebAuthnDeviceDataExport {
data := WebAuthnDeviceDataExport{
WebAuthnDevices: make([]WebAuthnDeviceData, len(export.WebAuthnDevices)),
}
for i, device := range export.WebAuthnDevices {
data.WebAuthnDevices[i] = device.ToData()
}
return data
}
// MarshalYAML marshals this model into YAML.
func (export WebAuthnDeviceExport) MarshalYAML() (any, error) {
return export.ToData(), nil
}

View File

@ -0,0 +1,88 @@
package model
import (
"crypto/rand"
"database/sql"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestWebAuthnDeviceImportExport(t *testing.T) {
have := WebAuthnDeviceExport{
WebAuthnDevices: []WebAuthnDevice{
{
ID: 0,
CreatedAt: time.Now(),
LastUsedAt: sql.NullTime{Time: time.Now(), Valid: true},
RPID: "example",
Username: "john",
Description: "akey",
KID: NewBase64(MustRead(20)),
PublicKey: MustRead(128),
AttestationType: "fido-u2f",
Transport: "",
AAGUID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
SignCount: 20,
CloneWarning: false,
},
{
ID: 0,
CreatedAt: time.Now(),
LastUsedAt: sql.NullTime{Valid: false},
RPID: "example2",
Username: "john2",
Description: "bkey",
KID: NewBase64(MustRead(60)),
PublicKey: MustRead(64),
AttestationType: "packed",
Transport: "",
AAGUID: uuid.NullUUID{Valid: false},
SignCount: 30,
CloneWarning: true,
},
},
}
out, err := yaml.Marshal(&have)
require.NoError(t, err)
imported := WebAuthnDeviceExport{}
require.NoError(t, yaml.Unmarshal(out, &imported))
require.Equal(t, len(have.WebAuthnDevices), len(imported.WebAuthnDevices))
for i, actual := range imported.WebAuthnDevices {
t.Run(actual.Description, func(t *testing.T) {
expected := have.WebAuthnDevices[i]
assert.Equal(t, expected.KID, actual.KID)
assert.Equal(t, expected.PublicKey, actual.PublicKey)
assert.Equal(t, expected.SignCount, actual.SignCount)
assert.Equal(t, expected.AttestationType, actual.AttestationType)
assert.Equal(t, expected.RPID, actual.RPID)
assert.Equal(t, expected.AAGUID.Valid, actual.AAGUID.Valid)
assert.Equal(t, expected.AAGUID.UUID, actual.AAGUID.UUID)
assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second)
assert.WithinDuration(t, expected.LastUsedAt.Time, actual.LastUsedAt.Time, time.Second)
assert.Equal(t, expected.LastUsedAt.Valid, actual.LastUsedAt.Valid)
assert.Equal(t, expected.CloneWarning, actual.CloneWarning)
assert.Equal(t, expected.Description, actual.Description)
assert.Equal(t, expected.Username, actual.Username)
})
}
}
func MustRead(n int) []byte {
data := make([]byte, n)
if _, err := rand.Read(data); err != nil {
panic(err)
}
return data
}