fix: multi-cookie domain webauthn

feat-otp-verification
James Elliott 2023-02-12 02:47:03 +11:00
parent 8c057f65a5
commit 3b6f5482b8
No known key found for this signature in database
GPG Key ID: 0F1C4A096E857E49
24 changed files with 363 additions and 54 deletions

View File

@ -503,8 +503,8 @@ more information.
{{< confkey type="string" default="auto" required="no" >}}
*__Important Note:__ the `implicit` consent mode is not technically part of the specification. It theoretically could be
misused in certain conditions specifically with public clients or when the client credentials (i.e. client secret) has
been exposed to an attacker. For these reasons this mode is discouraged.*
misused in certain conditions specifically with the public client type or when the client credentials (i.e. client
secret) has been exposed to an attacker. For these reasons this mode is discouraged.*
Configures the consent mode. The following table describes the different modes:

View File

@ -7,6 +7,7 @@ package authentication
import (
tls "crypto/tls"
reflect "reflect"
time "time"
ldap "github.com/go-ldap/ldap/v3"
gomock "github.com/golang/mock/gomock"
@ -35,6 +36,20 @@ func (m *MockLDAPClient) EXPECT() *MockLDAPClientMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockLDAPClient) Add(arg0 *ldap.AddRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Add indicates an expected call of Add.
func (mr *MockLDAPClientMockRecorder) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockLDAPClient)(nil).Add), arg0)
}
// Bind mocks base method.
func (m *MockLDAPClient) Bind(arg0, arg1 string) error {
m.ctrl.T.Helper()
@ -61,6 +76,92 @@ func (mr *MockLDAPClientMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLDAPClient)(nil).Close))
}
// Compare mocks base method.
func (m *MockLDAPClient) Compare(arg0, arg1, arg2 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Compare", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Compare indicates an expected call of Compare.
func (mr *MockLDAPClientMockRecorder) Compare(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Compare", reflect.TypeOf((*MockLDAPClient)(nil).Compare), arg0, arg1, arg2)
}
// Del mocks base method.
func (m *MockLDAPClient) Del(arg0 *ldap.DelRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Del", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Del indicates an expected call of Del.
func (mr *MockLDAPClientMockRecorder) Del(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockLDAPClient)(nil).Del), arg0)
}
// DigestMD5Bind mocks base method.
func (m *MockLDAPClient) DigestMD5Bind(arg0 *ldap.DigestMD5BindRequest) (*ldap.DigestMD5BindResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DigestMD5Bind", arg0)
ret0, _ := ret[0].(*ldap.DigestMD5BindResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DigestMD5Bind indicates an expected call of DigestMD5Bind.
func (mr *MockLDAPClientMockRecorder) DigestMD5Bind(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DigestMD5Bind", reflect.TypeOf((*MockLDAPClient)(nil).DigestMD5Bind), arg0)
}
// ExternalBind mocks base method.
func (m *MockLDAPClient) ExternalBind() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExternalBind")
ret0, _ := ret[0].(error)
return ret0
}
// ExternalBind indicates an expected call of ExternalBind.
func (mr *MockLDAPClientMockRecorder) ExternalBind() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExternalBind", reflect.TypeOf((*MockLDAPClient)(nil).ExternalBind))
}
// IsClosing mocks base method.
func (m *MockLDAPClient) IsClosing() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosing")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosing indicates an expected call of IsClosing.
func (mr *MockLDAPClientMockRecorder) IsClosing() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosing", reflect.TypeOf((*MockLDAPClient)(nil).IsClosing))
}
// MD5Bind mocks base method.
func (m *MockLDAPClient) MD5Bind(arg0, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MD5Bind", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// MD5Bind indicates an expected call of MD5Bind.
func (mr *MockLDAPClientMockRecorder) MD5Bind(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MD5Bind", reflect.TypeOf((*MockLDAPClient)(nil).MD5Bind), arg0, arg1, arg2)
}
// Modify mocks base method.
func (m *MockLDAPClient) Modify(arg0 *ldap.ModifyRequest) error {
m.ctrl.T.Helper()
@ -75,6 +176,35 @@ func (mr *MockLDAPClientMockRecorder) Modify(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Modify", reflect.TypeOf((*MockLDAPClient)(nil).Modify), arg0)
}
// ModifyDN mocks base method.
func (m *MockLDAPClient) ModifyDN(arg0 *ldap.ModifyDNRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ModifyDN", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ModifyDN indicates an expected call of ModifyDN.
func (mr *MockLDAPClientMockRecorder) ModifyDN(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifyDN", reflect.TypeOf((*MockLDAPClient)(nil).ModifyDN), arg0)
}
// ModifyWithResult mocks base method.
func (m *MockLDAPClient) ModifyWithResult(arg0 *ldap.ModifyRequest) (*ldap.ModifyResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ModifyWithResult", arg0)
ret0, _ := ret[0].(*ldap.ModifyResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ModifyWithResult indicates an expected call of ModifyWithResult.
func (mr *MockLDAPClientMockRecorder) ModifyWithResult(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifyWithResult", reflect.TypeOf((*MockLDAPClient)(nil).ModifyWithResult), arg0)
}
// PasswordModify mocks base method.
func (m *MockLDAPClient) PasswordModify(arg0 *ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) {
m.ctrl.T.Helper()
@ -105,6 +235,60 @@ func (mr *MockLDAPClientMockRecorder) Search(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockLDAPClient)(nil).Search), arg0)
}
// SearchWithPaging mocks base method.
func (m *MockLDAPClient) SearchWithPaging(arg0 *ldap.SearchRequest, arg1 uint32) (*ldap.SearchResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SearchWithPaging", arg0, arg1)
ret0, _ := ret[0].(*ldap.SearchResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SearchWithPaging indicates an expected call of SearchWithPaging.
func (mr *MockLDAPClientMockRecorder) SearchWithPaging(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchWithPaging", reflect.TypeOf((*MockLDAPClient)(nil).SearchWithPaging), arg0, arg1)
}
// SetTimeout mocks base method.
func (m *MockLDAPClient) SetTimeout(arg0 time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTimeout", arg0)
}
// SetTimeout indicates an expected call of SetTimeout.
func (mr *MockLDAPClientMockRecorder) SetTimeout(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTimeout", reflect.TypeOf((*MockLDAPClient)(nil).SetTimeout), arg0)
}
// SimpleBind mocks base method.
func (m *MockLDAPClient) SimpleBind(arg0 *ldap.SimpleBindRequest) (*ldap.SimpleBindResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SimpleBind", arg0)
ret0, _ := ret[0].(*ldap.SimpleBindResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SimpleBind indicates an expected call of SimpleBind.
func (mr *MockLDAPClientMockRecorder) SimpleBind(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SimpleBind", reflect.TypeOf((*MockLDAPClient)(nil).SimpleBind), arg0)
}
// Start mocks base method.
func (m *MockLDAPClient) Start() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Start")
}
// Start indicates an expected call of Start.
func (mr *MockLDAPClientMockRecorder) Start() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockLDAPClient)(nil).Start))
}
// StartTLS mocks base method.
func (m *MockLDAPClient) StartTLS(arg0 *tls.Config) error {
m.ctrl.T.Helper()
@ -119,6 +303,21 @@ func (mr *MockLDAPClientMockRecorder) StartTLS(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartTLS", reflect.TypeOf((*MockLDAPClient)(nil).StartTLS), arg0)
}
// TLSConnectionState mocks base method.
func (m *MockLDAPClient) TLSConnectionState() (tls.ConnectionState, bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TLSConnectionState")
ret0, _ := ret[0].(tls.ConnectionState)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// TLSConnectionState indicates an expected call of TLSConnectionState.
func (mr *MockLDAPClientMockRecorder) TLSConnectionState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TLSConnectionState", reflect.TypeOf((*MockLDAPClient)(nil).TLSConnectionState))
}
// UnauthenticatedBind mocks base method.
func (m *MockLDAPClient) UnauthenticatedBind(arg0 string) error {
m.ctrl.T.Helper()
@ -132,3 +331,32 @@ func (mr *MockLDAPClientMockRecorder) UnauthenticatedBind(arg0 interface{}) *gom
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnauthenticatedBind", reflect.TypeOf((*MockLDAPClient)(nil).UnauthenticatedBind), arg0)
}
// Unbind mocks base method.
func (m *MockLDAPClient) Unbind() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unbind")
ret0, _ := ret[0].(error)
return ret0
}
// Unbind indicates an expected call of Unbind.
func (mr *MockLDAPClientMockRecorder) Unbind() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unbind", reflect.TypeOf((*MockLDAPClient)(nil).Unbind))
}
// WhoAmI mocks base method.
func (m *MockLDAPClient) WhoAmI(arg0 []ldap.Control) (*ldap.WhoAmIResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WhoAmI", arg0)
ret0, _ := ret[0].(*ldap.WhoAmIResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// WhoAmI indicates an expected call of WhoAmI.
func (mr *MockLDAPClientMockRecorder) WhoAmI(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WhoAmI", reflect.TypeOf((*MockLDAPClient)(nil).WhoAmI), arg0)
}

View File

@ -3,6 +3,7 @@ package authentication
import (
"crypto/tls"
"net/mail"
"time"
"github.com/go-ldap/ldap/v3"
"golang.org/x/text/encoding/unicode"
@ -17,16 +18,35 @@ type LDAPClientFactory interface {
//
// Methods added to this interface that have a direct correlation with one from ldap.Client should have the same signature.
type LDAPClient interface {
Start()
Close()
IsClosing() bool
SetTimeout(timeout time.Duration)
TLSConnectionState() (state tls.ConnectionState, ok bool)
StartTLS(config *tls.Config) (err error)
Unbind() (err error)
Bind(username, password string) (err error)
SimpleBind(simpleBindRequest *ldap.SimpleBindRequest) (bindResult *ldap.SimpleBindResult, err error)
MD5Bind(host string, username string, password string) (err error)
DigestMD5Bind(digestMD5BindRequest *ldap.DigestMD5BindRequest) (digestMD5BindResult *ldap.DigestMD5BindResult, err error)
UnauthenticatedBind(username string) (err error)
ExternalBind() (err error)
Modify(modifyRequest *ldap.ModifyRequest) (err error)
ModifyWithResult(modifyRequest *ldap.ModifyRequest) (modifyResult *ldap.ModifyResult, err error)
ModifyDN(m *ldap.ModifyDNRequest) (err error)
PasswordModify(pwdModifyRequest *ldap.PasswordModifyRequest) (pwdModifyResult *ldap.PasswordModifyResult, err error)
Add(addRequest *ldap.AddRequest) (err error)
Del(delRequest *ldap.DelRequest) (err error)
Search(searchRequest *ldap.SearchRequest) (searchResult *ldap.SearchResult, err error)
SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (searchResult *ldap.SearchResult, err error)
Compare(dn string, attribute string, value string) (same bool, err error)
WhoAmI(controls []ldap.Control) (whoamiResult *ldap.WhoAmIResult, err error)
}
// UserDetails represent the details retrieved for a given user.

View File

@ -550,7 +550,7 @@ func (ctx *CmdCtx) StorageUserWebauthnListRunE(cmd *cobra.Command, args []string
user := args[0]
devices, err = ctx.providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, user)
devices, err = ctx.providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, "", user)
switch {
case len(devices) == 0 || (err != nil && errors.Is(err, storage.ErrNoWebauthnDevice)):

View File

@ -40,7 +40,7 @@ func WebauthnRegistrationGET(ctx *middlewares.AutheliaCtx) {
return
}
if user, err = getWebAuthnUser(ctx, userSession); err != nil {
if user, err = getWebauthnUserByRPID(ctx, userSession, w.Config.RPID); err != nil {
ctx.Logger.Errorf("Unable to load %s devices for assertion challenge for user '%s': %+v", regulation.AuthTypeWebauthn, userSession.Username, err)
respondUnauthorized(ctx, messageMFAValidationFailed)
@ -132,7 +132,7 @@ func WebauthnRegistrationPOST(ctx *middlewares.AutheliaCtx) {
ctx.Logger.WithField("att_format", response.Response.AttestationObject.Format).Debug("Response Data")
if user, err = getWebAuthnUser(ctx, userSession); err != nil {
if user, err = getWebauthnUser(ctx, userSession); err != nil {
ctx.Logger.Errorf("Unable to load %s user details for registration for user '%s': %+v", regulation.AuthTypeWebauthn, userSession.Username, err)
respondUnauthorized(ctx, messageMFAValidationFailed)
@ -150,7 +150,7 @@ func WebauthnRegistrationPOST(ctx *middlewares.AutheliaCtx) {
ctx.Logger.WithField("att_type", credential.AttestationType).Debug("Credential Data")
devices, err := ctx.Providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, userSession.Username)
devices, err := ctx.Providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, w.Config.RPID, userSession.Username)
if err != nil && err != storage.ErrNoWebauthnDevice {
ctx.Logger.Errorf("Unable to load existing %s devices for for user '%s': %+v", regulation.AuthTypeWebauthn, userSession.Username, err)
@ -173,6 +173,13 @@ func WebauthnRegistrationPOST(ctx *middlewares.AutheliaCtx) {
device := model.NewWebauthnDeviceFromCredential(w.Config.RPID, userSession.Username, bodyJSON.Description, credential)
ctx.Logger.WithFields(map[string]any{
"RPID": device.RPID,
"ID": device.ID,
"KID": device.KID.String(),
"User": device.Username,
}).Debug("Registering New Device")
if err = ctx.Providers.StorageProvider.SaveWebauthnDevice(ctx, device); err != nil {
ctx.Logger.Errorf("Unable to save %s device registration for user '%s': %+v", regulation.AuthTypeWebauthn, userSession.Username, err)

View File

@ -37,7 +37,7 @@ func WebauthnAssertionGET(ctx *middlewares.AutheliaCtx) {
return
}
if user, err = getWebAuthnUser(ctx, userSession); err != nil {
if user, err = getWebauthnUserByRPID(ctx, userSession, w.Config.RPID); err != nil {
ctx.Logger.Errorf("Unable to load %s user details during authentication challenge for user '%s': %+v", regulation.AuthTypeWebauthn, userSession.Username, err)
respondUnauthorized(ctx, messageMFAValidationFailed)
@ -145,7 +145,7 @@ func WebauthnAssertionPOST(ctx *middlewares.AutheliaCtx) {
return
}
if user, err = getWebAuthnUser(ctx, userSession); err != nil {
if user, err = getWebauthnUserByRPID(ctx, userSession, w.Config.RPID); err != nil {
ctx.Logger.Errorf("Unable to load %s credentials for authentication challenge for user '%s': %+v", regulation.AuthTypeWebauthn, userSession.Username, err)
respondUnauthorized(ctx, messageMFAValidationFailed)

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/url"
"strconv"
"github.com/valyala/fasthttp"
@ -37,6 +38,7 @@ func getWebauthnDeviceIDFromContext(ctx *middlewares.AutheliaCtx) (int, error) {
func WebauthnDevicesGET(ctx *middlewares.AutheliaCtx) {
var (
userSession session.UserSession
origin *url.URL
err error
)
@ -48,7 +50,15 @@ func WebauthnDevicesGET(ctx *middlewares.AutheliaCtx) {
return
}
devices, err := ctx.Providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, userSession.Username)
if origin, err = ctx.GetOrigin(); err != nil {
ctx.Logger.WithError(err).Error("Error occurred retrieving origin")
ctx.ReplyForbidden()
return
}
devices, err := ctx.Providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, origin.Hostname(), userSession.Username)
if err != nil && err != storage.ErrNoWebauthnDevice {
ctx.Error(err, messageOperationFailed)

View File

@ -1,7 +1,6 @@
package handlers
import (
"fmt"
"net/url"
"strings"
@ -13,7 +12,11 @@ import (
"github.com/authelia/authelia/v4/internal/session"
)
func getWebAuthnUser(ctx *middlewares.AutheliaCtx, userSession session.UserSession) (user *model.WebauthnUser, err error) {
func getWebauthnUser(ctx *middlewares.AutheliaCtx, userSession session.UserSession) (user *model.WebauthnUser, err error) {
return getWebauthnUserByRPID(ctx, userSession, "")
}
func getWebauthnUserByRPID(ctx *middlewares.AutheliaCtx, userSession session.UserSession, rpid string) (user *model.WebauthnUser, err error) {
user = &model.WebauthnUser{
Username: userSession.Username,
DisplayName: userSession.DisplayName,
@ -23,7 +26,7 @@ func getWebAuthnUser(ctx *middlewares.AutheliaCtx, userSession session.UserSessi
user.DisplayName = user.Username
}
if user.Devices, err = ctx.Providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, userSession.Username); err != nil {
if user.Devices, err = ctx.Providers.StorageProvider.LoadWebauthnDevicesByUsername(ctx, rpid, userSession.Username); err != nil {
return nil, err
}
@ -32,20 +35,17 @@ func getWebAuthnUser(ctx *middlewares.AutheliaCtx, userSession session.UserSessi
func newWebauthn(ctx *middlewares.AutheliaCtx) (w *webauthn.WebAuthn, err error) {
var (
u *url.URL
origin *url.URL
)
if u, err = ctx.GetXOriginalURLOrXForwardedURL(); err != nil {
if origin, err = ctx.GetOrigin(); err != nil {
return nil, err
}
rpID := u.Hostname()
origin := fmt.Sprintf("%s://%s", u.Scheme, u.Host)
config := &webauthn.Config{
RPDisplayName: ctx.Configuration.Webauthn.DisplayName,
RPID: rpID,
RPOrigins: []string{origin},
RPID: origin.Hostname(),
RPOrigins: []string{origin.String()},
RPIcon: "",
AttestationPreference: ctx.Configuration.Webauthn.ConveyancePreference,

View File

@ -21,10 +21,10 @@ func TestWebauthnGetUser(t *testing.T) {
DisplayName: "John Smith",
}
ctx.StorageMock.EXPECT().LoadWebauthnDevicesByUsername(ctx.Ctx, "john").Return([]model.WebauthnDevice{
ctx.StorageMock.EXPECT().LoadWebauthnDevicesByUsername(ctx.Ctx, "example.com", "john").Return([]model.WebauthnDevice{
{
ID: 1,
RPID: "https://example.com",
RPID: "example.com",
Username: "john",
Description: "Primary",
KID: model.NewBase64([]byte("abc123")),
@ -47,7 +47,7 @@ func TestWebauthnGetUser(t *testing.T) {
},
}, nil)
user, err := getWebAuthnUser(ctx.Ctx, userSession)
user, err := getWebauthnUserByRPID(ctx.Ctx, userSession, "example.com")
require.NoError(t, err)
require.NotNil(t, user)
@ -64,7 +64,7 @@ func TestWebauthnGetUser(t *testing.T) {
require.Len(t, user.Devices, 2)
assert.Equal(t, 1, user.Devices[0].ID)
assert.Equal(t, "https://example.com", user.Devices[0].RPID)
assert.Equal(t, "example.com", user.Devices[0].RPID)
assert.Equal(t, "john", user.Devices[0].Username)
assert.Equal(t, "Primary", user.Devices[0].Description)
assert.Equal(t, "", user.Devices[0].Transport)
@ -106,10 +106,10 @@ func TestWebauthnGetUserWithoutDisplayName(t *testing.T) {
Username: "john",
}
ctx.StorageMock.EXPECT().LoadWebauthnDevicesByUsername(ctx.Ctx, "john").Return([]model.WebauthnDevice{
ctx.StorageMock.EXPECT().LoadWebauthnDevicesByUsername(ctx.Ctx, "example.com", "john").Return([]model.WebauthnDevice{
{
ID: 1,
RPID: "https://example.com",
RPID: "example.com",
Username: "john",
Description: "Primary",
KID: model.NewBase64([]byte("abc123")),
@ -120,7 +120,7 @@ func TestWebauthnGetUserWithoutDisplayName(t *testing.T) {
},
}, nil)
user, err := getWebAuthnUser(ctx.Ctx, userSession)
user, err := getWebauthnUserByRPID(ctx.Ctx, userSession, "example.com")
require.NoError(t, err)
require.NotNil(t, user)
@ -136,9 +136,9 @@ func TestWebauthnGetUserWithErr(t *testing.T) {
Username: "john",
}
ctx.StorageMock.EXPECT().LoadWebauthnDevicesByUsername(ctx.Ctx, "john").Return(nil, errors.New("not found"))
ctx.StorageMock.EXPECT().LoadWebauthnDevicesByUsername(ctx.Ctx, "example.com", "john").Return(nil, errors.New("not found"))
user, err := getWebAuthnUser(ctx.Ctx, userSession)
user, err := getWebauthnUserByRPID(ctx.Ctx, userSession, "example.com")
assert.EqualError(t, err, "not found")
assert.Nil(t, user)

View File

@ -527,6 +527,18 @@ func (ctx *AutheliaCtx) GetXOriginalURLOrXForwardedURL() (requestURI *url.URL, e
}
}
// GetOrigin returns the expected origin for requests from this endpoint.
func (ctx *AutheliaCtx) GetOrigin() (origin *url.URL, err error) {
if origin, err = ctx.GetXOriginalURLOrXForwardedURL(); err != nil {
return nil, err
}
origin.Path = ""
origin.RawPath = ""
return origin, nil
}
// IssuerURL returns the expected Issuer.
func (ctx *AutheliaCtx) IssuerURL() (issuerURL *url.URL, err error) {
issuerURL = &url.URL{

View File

@ -435,18 +435,18 @@ func (mr *MockStorageMockRecorder) LoadWebauthnDevices(arg0, arg1, arg2 interfac
}
// LoadWebauthnDevicesByUsername mocks base method.
func (m *MockStorage) LoadWebauthnDevicesByUsername(arg0 context.Context, arg1 string) ([]model.WebauthnDevice, error) {
func (m *MockStorage) LoadWebauthnDevicesByUsername(arg0 context.Context, arg1, arg2 string) ([]model.WebauthnDevice, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadWebauthnDevicesByUsername", arg0, arg1)
ret := m.ctrl.Call(m, "LoadWebauthnDevicesByUsername", arg0, arg1, arg2)
ret0, _ := ret[0].([]model.WebauthnDevice)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadWebauthnDevicesByUsername indicates an expected call of LoadWebauthnDevicesByUsername.
func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1, arg2)
}
// RevokeOAuth2Session mocks base method.

View File

@ -0,0 +1,2 @@
DROP INDEX webauthn_devices_lookup_key;
CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (rpid, username, description);

View File

@ -0,0 +1,2 @@
DROP INDEX webauthn_devices_lookup_key;
CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (rpid, username, description);

View File

@ -0,0 +1,2 @@
DROP INDEX webauthn_devices_lookup_key;
CREATE UNIQUE INDEX webauthn_devices_lookup_key ON webauthn_devices (rpid, username, description);

View File

@ -9,7 +9,7 @@ import (
const (
// This is the latest schema version for the purpose of tests.
LatestVersion = 7
LatestVersion = 8
)
func TestShouldObtainCorrectUpMigrations(t *testing.T) {

View File

@ -44,7 +44,7 @@ type Provider interface {
DeleteWebauthnDevice(ctx context.Context, kid string) (err error)
DeleteWebauthnDeviceByUsername(ctx context.Context, username, description string) (err error)
LoadWebauthnDevices(ctx context.Context, limit, page int) (devices []model.WebauthnDevice, err error)
LoadWebauthnDevicesByUsername(ctx context.Context, username string) (devices []model.WebauthnDevice, err error)
LoadWebauthnDevicesByUsername(ctx context.Context, rpid, username string) (devices []model.WebauthnDevice, err error)
LoadWebauthnDeviceByID(ctx context.Context, id int) (device *model.WebauthnDevice, err error)
SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error)

View File

@ -49,10 +49,11 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlUpsertWebauthnDevice: fmt.Sprintf(queryFmtUpsertWebauthnDevice, tableWebauthnDevices),
sqlSelectWebauthnDevices: fmt.Sprintf(queryFmtSelectWebauthnDevices, tableWebauthnDevices),
sqlSelectWebauthnDevicesByUsername: fmt.Sprintf(queryFmtSelectWebauthnDevicesByUsername, tableWebauthnDevices),
sqlSelectWebauthnDevicesByRPIDByUsername: fmt.Sprintf(queryFmtSelectWebauthnDevicesByRPIDByUsername, tableWebauthnDevices),
sqlSelectWebauthnDeviceByID: fmt.Sprintf(queryFmtSelectWebauthnDeviceByID, tableWebauthnDevices),
sqlUpdateWebauthnDeviceDescriptionByUsernameAndID: fmt.Sprintf(queryFmtUpdateUpdateWebauthnDeviceDescriptionByUsernameAndID, tableWebauthnDevices),
sqlUpdateWebauthnDevicePublicKey: fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices),
sqlUpdateWebauthnDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateWebauthnDevicePublicKeyByUsername, tableWebauthnDevices),
sqlUpdateWebauthnDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKeyByUsername, tableWebauthnDevices),
sqlUpdateWebauthnDeviceRecordSignIn: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignIn, tableWebauthnDevices),
sqlUpdateWebauthnDeviceRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignInByUsername, tableWebauthnDevices),
sqlDeleteWebauthnDevice: fmt.Sprintf(queryFmtDeleteWebauthnDevice, tableWebauthnDevices),
@ -163,10 +164,11 @@ type SQLProvider struct {
sqlUpdateTOTPConfigRecordSignInByUsername string
// Table: webauthn_devices.
sqlUpsertWebauthnDevice string
sqlSelectWebauthnDevices string
sqlSelectWebauthnDevicesByUsername string
sqlSelectWebauthnDeviceByID string
sqlUpsertWebauthnDevice string
sqlSelectWebauthnDevices string
sqlSelectWebauthnDevicesByUsername string
sqlSelectWebauthnDevicesByRPIDByUsername string
sqlSelectWebauthnDeviceByID string
sqlUpdateWebauthnDeviceDescriptionByUsernameAndID string
sqlUpdateWebauthnDevicePublicKey string
@ -940,8 +942,15 @@ func (p *SQLProvider) LoadWebauthnDeviceByID(ctx context.Context, id int) (devic
}
// LoadWebauthnDevicesByUsername loads all webauthn devices registration for a given username.
func (p *SQLProvider) LoadWebauthnDevicesByUsername(ctx context.Context, username string) (devices []model.WebauthnDevice, err error) {
if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevicesByUsername, username); err != nil {
func (p *SQLProvider) LoadWebauthnDevicesByUsername(ctx context.Context, rpid, username string) (devices []model.WebauthnDevice, err error) {
switch len(rpid) {
case 0:
err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevicesByUsername, username)
default:
err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevicesByRPIDByUsername, rpid, username)
}
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return devices, ErrNoWebauthnDevice
}

View File

@ -134,6 +134,11 @@ const (
FROM %s
WHERE username = ?;`
queryFmtSelectWebauthnDevicesByRPIDByUsername = `
SELECT id, created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning
FROM %s
WHERE rpid = ? AND username = ?;`
queryFmtSelectWebauthnDeviceByID = `
SELECT id, created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning
FROM %s
@ -144,7 +149,7 @@ const (
SET public_key = ?
WHERE id = ?;`
queryFmtUpdateUpdateWebauthnDevicePublicKeyByUsername = `
queryFmtUpdateWebauthnDevicePublicKeyByUsername = `
UPDATE %s
SET public_key = ?
WHERE username = ? AND kid = ?;`
@ -175,8 +180,8 @@ const (
queryFmtUpsertWebauthnDevicePostgreSQL = `
INSERT INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (username, description)
DO UPDATE SET created_at = $1, last_used_at = $2, rpid = $3, kid = $6, public_key = $7, attestation_type = $8, transport = $9, aaguid = $10, sign_count = $11, clone_warning = $12;`
ON CONFLICT (rpid, username, description)
DO UPDATE SET created_at = $1, last_used_at = $2, kid = $6, public_key = $7, attestation_type = $8, transport = $9, aaguid = $10, sign_count = $11, clone_warning = $12;`
queryFmtDeleteWebauthnDevice = `
DELETE FROM %s

View File

@ -244,14 +244,16 @@ func (p *SQLProvider) schemaMigrateLock(ctx context.Context, conn SQLXConnection
}
func (p *SQLProvider) schemaMigrateApply(ctx context.Context, conn SQLXConnection, migration model.SchemaMigration) (err error) {
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
if migration.Query != "" {
if _, err = conn.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
return err
if migration.Version == 1 && migration.Up {
// Add the schema encryption value if upgrading to v1.
if err = p.setNewEncryptionCheckValue(ctx, conn, &p.key); err != nil {
return err
}
}
}

View File

@ -82,7 +82,7 @@
"^.+\\.(css|png|svg)$": "jest-transform-stub"
},
"transformIgnorePatterns": [
"[/\\\\]node_modules[/\\\\](?!(\\.pnpm[/\\\\])?(@fortawesome[+/\\\\]fontawesome-svg-core|@simplewebauthn[+/\\\\]browser))"
"[/\\\\]node_modules[/\\\\](?!(\\.pnpm[/\\\\])?(@simplewebauthn[+/\\\\]browser)).+\\.(js|jsx|cjs|ts|tsx)$"
],
"moduleNameMapper": {
"^@root/(.*)$": [

View File

@ -144,12 +144,13 @@ export async function startWebauthnRegistration(options: PublicKeyCredentialCrea
};
try {
console.log(JSON.stringify(options));
result.response = await startRegistration(options);
} catch (e) {
const exception = e as DOMException;
if (exception !== undefined) {
result.result = getAttestationResultFromDOMException(exception);
console.error(exception);
return result;
} else {
console.error(`Unhandled exception occurred during WebAuthn attestation: ${e}`);
@ -175,6 +176,8 @@ export async function getAuthenticationResult(options: PublicKeyCredentialReques
if (exception !== undefined) {
result.result = getAssertionResultFromDOMException(exception, options);
console.error(exception);
return result;
} else {
console.error(`Unhandled exception occurred during WebAuthn authentication: ${e}`);

View File

@ -23,7 +23,9 @@ import { finishRegistration, getAttestationCreationOptions, startWebauthnRegistr
const steps = ["Confirm device", "Choose name"];
interface Props {}
interface Props {
est: AuthenticatorSelectionCriteria;
}
const RegisterWebauthn = function (props: Props) {
const [state, setState] = useState(WebauthnTouchState.WaitTouch);
@ -69,12 +71,16 @@ const RegisterWebauthn = function (props: Props) {
return;
}
console.log("start registration");
try {
setState(WebauthnTouchState.WaitTouch);
setActiveStep(0);
const res = await startWebauthnRegistration(options);
console.log("got response", res.result);
if (res.result === AttestationResult.Success) {
if (res.response == null) {
throw new Error("Attestation request succeeded but credential is empty");

View File

@ -40,6 +40,7 @@ export default function WebauthnDevicesStack(props: Props) {
<Stack spacing={3}>
{devices.map((x, idx) => (
<WebauthnDeviceItem
key={idx}
index={idx}
device={x}
handleDeviceEdit={handleEdit}