diff --git a/api/openapi.yml b/api/openapi.yml index 1ee42c610..74f4dad16 100644 --- a/api/openapi.yml +++ b/api/openapi.yml @@ -317,6 +317,24 @@ paths: description: Forbidden security: - authelia_auth: [] + post: + tags: + - User Information + summary: User Configuration + description: > + The user info endpoint provides detailed information including a users display name, preferred and registered + second factor method(s). The POST method also ensures the preferred method is configured correctly. + responses: + "200": + description: Successful Operation + content: + application/json: + schema: + $ref: '#/components/schemas/handlers.UserInfo' + "403": + description: Forbidden + security: + - authelia_auth: [] /api/user/info/totp: get: tags: diff --git a/internal/authentication/const.go b/internal/authentication/const.go index 9bc072a44..0dbe0674c 100644 --- a/internal/authentication/const.go +++ b/internal/authentication/const.go @@ -16,15 +16,6 @@ const ( TwoFactor Level = iota ) -const ( - // TOTP Method using Time-Based One-Time Password applications like Google Authenticator. - TOTP = "totp" - // Webauthn Method using Webauthn devices like YubiKeys. - Webauthn = "webauthn" - // Push Method using Duo application to receive push notifications. - Push = "mobile_push" -) - const ( ldapSupportedExtensionAttribute = "supportedExtension" ldapOIDPasswdModifyExtension = "1.3.6.1.4.1.4203.1.11.1" // http://oidref.com/1.3.6.1.4.1.4203.1.11.1 @@ -36,9 +27,6 @@ const ( ldapPlaceholderUsername = "{username}" ) -// PossibleMethods is the set of all possible 2FA methods. -var PossibleMethods = []string{TOTP, Webauthn, Push} - // CryptAlgo the crypt representation of an algorithm used in the prefix of the hash. type CryptAlgo string diff --git a/internal/handlers/handler_configuration.go b/internal/handlers/handler_configuration.go index 5474b861c..d053b56e5 100644 --- a/internal/handlers/handler_configuration.go +++ b/internal/handlers/handler_configuration.go @@ -1,7 +1,6 @@ package handlers import ( - "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/middlewares" ) @@ -12,17 +11,7 @@ func ConfigurationGet(ctx *middlewares.AutheliaCtx) { } if ctx.Providers.Authorizer.IsSecondFactorEnabled() { - if !ctx.Configuration.TOTP.Disable { - body.AvailableMethods = append(body.AvailableMethods, authentication.TOTP) - } - - if !ctx.Configuration.Webauthn.Disable { - body.AvailableMethods = append(body.AvailableMethods, authentication.Webauthn) - } - - if ctx.Configuration.DuoAPI != nil { - body.AvailableMethods = append(body.AvailableMethods, authentication.Push) - } + body.AvailableMethods = ctx.AvailableSecondFactorMethods() } ctx.Logger.Tracef("Available methods are %s", body.AvailableMethods) diff --git a/internal/handlers/handler_user_info.go b/internal/handlers/handler_user_info.go index d3c78cafc..e11a098c9 100644 --- a/internal/handlers/handler_user_info.go +++ b/internal/handlers/handler_user_info.go @@ -1,16 +1,61 @@ package handlers import ( + "database/sql" + "errors" "fmt" "strings" - "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/utils" ) -// UserInfoGet get the info related to the user identified by the session. -func UserInfoGet(ctx *middlewares.AutheliaCtx) { +// UserInfoPOST handles setting up info for users if necessary when they login. +func UserInfoPOST(ctx *middlewares.AutheliaCtx) { + userSession := ctx.GetSession() + + var ( + userInfo model.UserInfo + err error + ) + + if _, err = ctx.Providers.StorageProvider.LoadPreferred2FAMethod(ctx, userSession.Username); err != nil { + if errors.Is(err, sql.ErrNoRows) { + if err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, ""); err != nil { + ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed) + } + } else { + ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed) + } + } + + if userInfo, err = ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username); err != nil { + ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed) + return + } + + var ( + changed bool + ) + + if changed = userInfo.SetDefaultPreferred2FAMethod(ctx.AvailableSecondFactorMethods()); changed { + if err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, userInfo.Method); err != nil { + ctx.Error(fmt.Errorf("unable to save user two factor method: %v", err), messageOperationFailed) + return + } + } + + userInfo.DisplayName = userSession.DisplayName + + err = ctx.SetJSONBody(userInfo) + if err != nil { + ctx.Logger.Errorf("Unable to set user info response in body: %s", err) + } +} + +// UserInfoGET get the info related to the user identified by the session. +func UserInfoGET(ctx *middlewares.AutheliaCtx) { userSession := ctx.GetSession() userInfo, err := ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username) @@ -37,8 +82,8 @@ func MethodPreferencePost(ctx *middlewares.AutheliaCtx) { return } - if !utils.IsStringInSlice(bodyJSON.Method, authentication.PossibleMethods) { - ctx.Error(fmt.Errorf("unknown method '%s', it should be one of %s", bodyJSON.Method, strings.Join(authentication.PossibleMethods, ", ")), messageOperationFailed) + if !utils.IsStringInSlice(bodyJSON.Method, ctx.AvailableSecondFactorMethods()) { + ctx.Error(fmt.Errorf("unknown or unavailable method '%s', it should be one of %s", bodyJSON.Method, strings.Join(ctx.AvailableSecondFactorMethods(), ", ")), messageOperationFailed) return } diff --git a/internal/handlers/handler_user_info_test.go b/internal/handlers/handler_user_info_test.go index b3b3ec646..d2fd256a7 100644 --- a/internal/handlers/handler_user_info_test.go +++ b/internal/handlers/handler_user_info_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/model" ) @@ -41,7 +42,17 @@ type expectedResponse struct { err error } -func TestMethodSetToU2F(t *testing.T) { +type expectedResponseAlt struct { + description string + + db model.UserInfo + api *model.UserInfo + loadErr error + saveErr error + config *schema.Configuration +} + +func TestUserInfoEndpoint_SetCorrectMethod(t *testing.T) { expectedResponses := []expectedResponse{ { db: model.UserInfo{ @@ -89,6 +100,9 @@ func TestMethodSetToU2F(t *testing.T) { } mock := mocks.NewMockAutheliaCtx(t) + + mock.Ctx.Configuration.DuoAPI = &schema.DuoAPIConfiguration{} + // Set the initial user session. userSession := mock.Ctx.GetSession() userSession.Username = testUsername @@ -101,7 +115,7 @@ func TestMethodSetToU2F(t *testing.T) { LoadUserInfo(mock.Ctx, gomock.Eq("john")). Return(resp.db, resp.err) - UserInfoGet(mock.Ctx) + UserInfoGET(mock.Ctx) if resp.err == nil { t.Run("expected status code", func(t *testing.T) { @@ -123,6 +137,207 @@ func TestMethodSetToU2F(t *testing.T) { t.Run("registered totp", func(t *testing.T) { assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP) }) + + t.Run("registered duo", func(t *testing.T) { + assert.Equal(t, resp.api.HasDuo, actualPreferences.HasDuo) + }) + } else { + t.Run("expected status code", func(t *testing.T) { + assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) + }) + + errResponse := mock.GetResponseError(t) + + assert.Equal(t, "KO", errResponse.Status) + assert.Equal(t, "Operation failed.", errResponse.Message) + } + + mock.Close() + } +} + +func TestUserInfoEndpoint_SetDefaultMethod(t *testing.T) { + expectedResponses := []expectedResponseAlt{ + { + description: "should set method to totp by default even when user doesn't have totp configured and no preferred method", + db: model.UserInfo{ + Method: "", + HasTOTP: false, + HasWebauthn: false, + HasDuo: false, + }, + api: &model.UserInfo{ + Method: "totp", + HasTOTP: false, + HasWebauthn: false, + HasDuo: false, + }, + config: &schema.Configuration{ + DuoAPI: &schema.DuoAPIConfiguration{}, + }, + loadErr: nil, + saveErr: nil, + }, + { + description: "should set method to duo by default when user has duo configured and no preferred method", + db: model.UserInfo{ + Method: "", + HasTOTP: false, + HasWebauthn: false, + HasDuo: true, + }, + api: &model.UserInfo{ + Method: "mobile_push", + HasTOTP: false, + HasWebauthn: false, + HasDuo: true, + }, + config: &schema.Configuration{ + DuoAPI: &schema.DuoAPIConfiguration{}, + }, + loadErr: nil, + saveErr: nil, + }, + { + description: "should set method to totp by default when user has duo configured and no preferred method but duo is not enabled", + db: model.UserInfo{ + Method: "", + HasTOTP: false, + HasWebauthn: false, + HasDuo: true, + }, + api: &model.UserInfo{ + Method: "totp", + HasTOTP: false, + HasWebauthn: false, + HasDuo: true, + }, + loadErr: nil, + saveErr: nil, + }, + { + description: "should set method to duo by default when user has duo configured and no preferred method", + db: model.UserInfo{ + Method: "", + HasTOTP: true, + HasWebauthn: true, + HasDuo: true, + }, + api: &model.UserInfo{ + Method: "webauthn", + HasTOTP: true, + HasWebauthn: true, + HasDuo: true, + }, + config: &schema.Configuration{ + TOTP: schema.TOTPConfiguration{ + Disable: true, + }, + DuoAPI: &schema.DuoAPIConfiguration{}, + }, + loadErr: nil, + saveErr: nil, + }, + { + description: "should default new users to totp if all enabled", + db: model.UserInfo{ + Method: "", + HasTOTP: false, + HasWebauthn: false, + HasDuo: false, + }, + api: &model.UserInfo{ + Method: "totp", + HasTOTP: true, + HasWebauthn: true, + HasDuo: true, + }, + config: &schema.Configuration{ + DuoAPI: &schema.DuoAPIConfiguration{}, + }, + loadErr: nil, + saveErr: errors.New("could not save"), + }, + } + + for _, resp := range expectedResponses { + if resp.api == nil { + resp.api = &resp.db + } + + mock := mocks.NewMockAutheliaCtx(t) + + if resp.config != nil { + mock.Ctx.Configuration = *resp.config + } + + // Set the initial user session. + userSession := mock.Ctx.GetSession() + userSession.Username = testUsername + userSession.AuthenticationLevel = 1 + err := mock.Ctx.SaveSession(userSession) + require.NoError(t, err) + + if resp.db.Method == "" { + gomock.InOrder( + mock.StorageMock. + EXPECT(). + LoadPreferred2FAMethod(mock.Ctx, gomock.Eq("john")). + Return("", sql.ErrNoRows), + mock.StorageMock. + EXPECT(). + SavePreferred2FAMethod(mock.Ctx, gomock.Eq("john"), gomock.Eq("")). + Return(resp.saveErr), + mock.StorageMock. + EXPECT(). + LoadUserInfo(mock.Ctx, gomock.Eq("john")). + Return(resp.db, nil), + mock.StorageMock.EXPECT(). + SavePreferred2FAMethod(mock.Ctx, gomock.Eq("john"), gomock.Eq(resp.api.Method)). + Return(resp.saveErr), + ) + } else { + gomock.InOrder( + mock.StorageMock. + EXPECT(). + LoadPreferred2FAMethod(mock.Ctx, gomock.Eq("john")). + Return(resp.db.Method, nil), + mock.StorageMock. + EXPECT(). + LoadUserInfo(mock.Ctx, gomock.Eq("john")). + Return(resp.db, nil), + mock.StorageMock.EXPECT(). + SavePreferred2FAMethod(mock.Ctx, gomock.Eq("john"), gomock.Eq(resp.api.Method)). + Return(resp.saveErr), + ) + } + + UserInfoPOST(mock.Ctx) + + if resp.loadErr == nil && resp.saveErr == nil { + t.Run(fmt.Sprintf("%s/%s", resp.description, "expected status code"), func(t *testing.T) { + assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) + }) + + actualPreferences := model.UserInfo{} + + mock.GetResponseData(t, &actualPreferences) + + t.Run(fmt.Sprintf("%s/%s", resp.description, "expected method"), func(t *testing.T) { + assert.Equal(t, resp.api.Method, actualPreferences.Method) + }) + + t.Run(fmt.Sprintf("%s/%s", resp.description, "registered webauthn"), func(t *testing.T) { + assert.Equal(t, resp.api.HasWebauthn, actualPreferences.HasWebauthn) + }) + + t.Run(fmt.Sprintf("%s/%s", resp.description, "registered totp"), func(t *testing.T) { + assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP) + }) + + t.Run(fmt.Sprintf("%s/%s", resp.description, "registered duo"), func(t *testing.T) { + assert.Equal(t, resp.api.HasDuo, actualPreferences.HasDuo) + }) } else { t.Run("expected status code", func(t *testing.T) { assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) @@ -143,7 +358,7 @@ func (s *FetchSuite) TestShouldReturnError500WhenStorageFailsToLoad() { LoadUserInfo(s.mock.Ctx, gomock.Eq("john")). Return(model.UserInfo{}, fmt.Errorf("failure")) - UserInfoGet(s.mock.Ctx) + UserInfoGET(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed.") assert.Equal(s.T(), "unable to load user information: failure", s.mock.Hook.LastEntry().Message) @@ -205,7 +420,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenBadMethodProvided() { MethodPreferencePost(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed.") - assert.Equal(s.T(), "unknown method 'abc', it should be one of totp, webauthn, mobile_push", s.mock.Hook.LastEntry().Message) + assert.Equal(s.T(), "unknown or unavailable method 'abc', it should be one of totp, webauthn", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), logrus.ErrorLevel, s.mock.Hook.LastEntry().Level) } diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 03016a688..4b209176b 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -14,6 +14,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/utils" ) @@ -54,6 +55,25 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers) } } +// AvailableSecondFactorMethods returns the available 2FA methods. +func (ctx *AutheliaCtx) AvailableSecondFactorMethods() (methods []string) { + methods = make([]string, 0, 3) + + if !ctx.Configuration.TOTP.Disable { + methods = append(methods, model.SecondFactorMethodTOTP) + } + + if !ctx.Configuration.Webauthn.Disable { + methods = append(methods, model.SecondFactorMethodWebauthn) + } + + if ctx.Configuration.DuoAPI != nil { + methods = append(methods, model.SecondFactorMethodDuo) + } + + return methods +} + // Error reply with an error and display the stack trace in the logs. func (ctx *AutheliaCtx) Error(err error, message string) { ctx.SetJSONError(message) diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go index 5a26f1c78..0dbb29512 100644 --- a/internal/middlewares/authelia_context_test.go +++ b/internal/middlewares/authelia_context_test.go @@ -11,6 +11,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/mocks" + "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/session" ) @@ -115,3 +116,26 @@ func TestShouldDetectNonXHR(t *testing.T) { assert.False(t, mock.Ctx.IsXHR()) } + +func TestShouldReturnCorrectSecondFactorMethods(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + assert.Equal(t, []string{model.SecondFactorMethodTOTP, model.SecondFactorMethodWebauthn}, mock.Ctx.AvailableSecondFactorMethods()) + + mock.Ctx.Configuration.DuoAPI = &schema.DuoAPIConfiguration{} + + assert.Equal(t, []string{model.SecondFactorMethodTOTP, model.SecondFactorMethodWebauthn, model.SecondFactorMethodDuo}, mock.Ctx.AvailableSecondFactorMethods()) + + mock.Ctx.Configuration.TOTP.Disable = true + + assert.Equal(t, []string{model.SecondFactorMethodWebauthn, model.SecondFactorMethodDuo}, mock.Ctx.AvailableSecondFactorMethods()) + + mock.Ctx.Configuration.Webauthn.Disable = true + + assert.Equal(t, []string{model.SecondFactorMethodDuo}, mock.Ctx.AvailableSecondFactorMethods()) + + mock.Ctx.Configuration.DuoAPI = nil + + assert.Equal(t, []string{}, mock.Ctx.AvailableSecondFactorMethods()) +} diff --git a/internal/model/const.go b/internal/model/const.go index d100742b9..efdf6ef38 100644 --- a/internal/model/const.go +++ b/internal/model/const.go @@ -6,3 +6,14 @@ const ( errFmtScanInvalidType = "cannot scan model type '%T' from type '%T' with value '%v'" errFmtScanInvalidTypeErr = "cannot scan model type '%T' from type '%T' with value '%v': %w" ) + +const ( + // SecondFactorMethodTOTP method using Time-Based One-Time Password applications like Google Authenticator. + SecondFactorMethodTOTP = "totp" + + // SecondFactorMethodWebauthn method using Webauthn devices like YubiKey's. + SecondFactorMethodWebauthn = "webauthn" + + // SecondFactorMethodDuo method using Duo application to receive push notifications. + SecondFactorMethodDuo = "mobile_push" +) diff --git a/internal/model/user_info.go b/internal/model/user_info.go index c43df85ff..50fc92945 100644 --- a/internal/model/user_info.go +++ b/internal/model/user_info.go @@ -1,5 +1,9 @@ package model +import ( + "github.com/authelia/authelia/v4/internal/utils" +) + // UserInfo represents the user information required by the web UI. type UserInfo struct { // The users display name. @@ -17,3 +21,38 @@ type UserInfo struct { // True if a duo device has been configured as the preferred. HasDuo bool `db:"has_duo" json:"has_duo" valid:"required"` } + +// SetDefaultPreferred2FAMethod configures the default method based on what is configured as available and the users available methods. +func (i *UserInfo) SetDefaultPreferred2FAMethod(methods []string) (changed bool) { + if len(methods) == 0 { + // No point attempting to change the method if no methods are available. + return false + } + + before := i.Method + + totp, webauthn, duo := utils.IsStringInSlice(SecondFactorMethodTOTP, methods), utils.IsStringInSlice(SecondFactorMethodWebauthn, methods), utils.IsStringInSlice(SecondFactorMethodDuo, methods) + + if i.Method != "" && !utils.IsStringInSlice(i.Method, methods) { + i.Method = "" + } + + if i.Method == "" { + switch { + case i.HasTOTP && totp: + i.Method = SecondFactorMethodTOTP + case i.HasWebauthn && webauthn: + i.Method = SecondFactorMethodWebauthn + case i.HasDuo && duo: + i.Method = SecondFactorMethodDuo + case totp: + i.Method = SecondFactorMethodTOTP + case webauthn: + i.Method = SecondFactorMethodWebauthn + case duo: + i.Method = SecondFactorMethodDuo + } + } + + return before != i.Method +} diff --git a/internal/model/user_info_test.go b/internal/model/user_info_test.go new file mode 100644 index 000000000..cefe9ac7f --- /dev/null +++ b/internal/model/user_info_test.go @@ -0,0 +1,222 @@ +package model + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUserInfo_SetDefaultMethod_ShouldConfigureConfigDefault(t *testing.T) { + none := "none" + + testName := func(i int, have UserInfo, availableMethods []string) string { + method := have.Method + + if method == "" { + method = none + } + + has := "" + + if have.HasTOTP || have.HasDuo || have.HasWebauthn { + has += " has" + + if have.HasTOTP { + has += " " + SecondFactorMethodTOTP + } + + if have.HasDuo { + has += " " + SecondFactorMethodDuo + } + + if have.HasWebauthn { + has += " " + SecondFactorMethodWebauthn + } + } + + available := none + if len(availableMethods) != 0 { + available = strings.Join(availableMethods, " ") + } + + return fmt.Sprintf("%d/method %s%s/available methods %s", i+1, method, has, available) + } + + testCases := []struct { + have UserInfo + availableMethods []string + changed bool + want UserInfo + }{ + { + have: UserInfo{ + Method: SecondFactorMethodTOTP, + HasDuo: true, + HasTOTP: true, + HasWebauthn: true, + }, + availableMethods: []string{SecondFactorMethodWebauthn, SecondFactorMethodDuo}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: true, + HasTOTP: true, + HasWebauthn: true, + }, + }, + { + have: UserInfo{ + HasDuo: true, + HasTOTP: true, + HasWebauthn: true, + }, + availableMethods: []string{SecondFactorMethodTOTP, SecondFactorMethodWebauthn, SecondFactorMethodDuo}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodTOTP, + HasDuo: true, + HasTOTP: true, + HasWebauthn: true, + }, + }, + { + have: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: true, + HasTOTP: false, + HasWebauthn: false, + }, + availableMethods: []string{SecondFactorMethodTOTP}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodTOTP, + HasDuo: true, + HasTOTP: false, + HasWebauthn: false, + }, + }, + { + have: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: false, + HasTOTP: false, + HasWebauthn: false, + }, + availableMethods: []string{SecondFactorMethodTOTP}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodTOTP, + HasDuo: false, + HasTOTP: false, + HasWebauthn: false, + }, + }, + { + have: UserInfo{ + Method: SecondFactorMethodTOTP, + HasDuo: false, + HasTOTP: false, + HasWebauthn: false, + }, + availableMethods: []string{SecondFactorMethodWebauthn}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: false, + HasTOTP: false, + HasWebauthn: false, + }, + }, + { + have: UserInfo{ + Method: SecondFactorMethodTOTP, + HasDuo: false, + HasTOTP: false, + HasWebauthn: false, + }, + availableMethods: []string{SecondFactorMethodDuo}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodDuo, + HasDuo: false, + HasTOTP: false, + HasWebauthn: false, + }, + }, + { + have: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + availableMethods: []string{SecondFactorMethodTOTP, SecondFactorMethodWebauthn, SecondFactorMethodDuo}, + changed: false, + want: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + }, + { + have: UserInfo{ + Method: "", + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + availableMethods: []string{SecondFactorMethodWebauthn, SecondFactorMethodDuo}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodWebauthn, + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + }, + { + have: UserInfo{ + Method: "", + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + availableMethods: []string{SecondFactorMethodDuo}, + changed: true, + want: UserInfo{ + Method: SecondFactorMethodDuo, + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + }, + { + have: UserInfo{ + Method: "", + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + availableMethods: nil, + changed: false, + want: UserInfo{ + Method: "", + HasDuo: false, + HasTOTP: true, + HasWebauthn: true, + }, + }, + } + + for i, tc := range testCases { + t.Run(testName(i, tc.have, tc.availableMethods), func(t *testing.T) { + changed := tc.have.SetDefaultPreferred2FAMethod(tc.availableMethods) + + assert.Equal(t, tc.changed, changed) + assert.Equal(t, tc.want, tc.have) + }) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 221ed5cdb..204804d14 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -86,7 +86,9 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr // Information about the user. r.GET("/api/user/info", autheliaMiddleware( - middlewares.RequireFirstFactor(handlers.UserInfoGet))) + middlewares.RequireFirstFactor(handlers.UserInfoGET))) + r.POST("/api/user/info", autheliaMiddleware( + middlewares.RequireFirstFactor(handlers.UserInfoPOST))) r.POST("/api/user/info/2fa_method", autheliaMiddleware( middlewares.RequireFirstFactor(handlers.MethodPreferencePost))) diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 7e4922a2f..126c4ddb1 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -11,7 +11,6 @@ import ( "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" - "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/model" @@ -205,7 +204,7 @@ func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username strin case err == nil: return method, nil case errors.Is(err, sql.ErrNoRows): - return "", nil + return "", sql.ErrNoRows default: return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err) } @@ -216,17 +215,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username) switch { - case err == nil: - return info, nil - case errors.Is(err, sql.ErrNoRows): - if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]); err != nil { - return model.UserInfo{}, fmt.Errorf("error upserting preferred two factor method while selecting user info for user '%s': %w", username, err) - } - - if err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username); err != nil { - return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) - } - + case err == nil, errors.Is(err, sql.ErrNoRows): return info, nil default: return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) diff --git a/web/src/hooks/UserInfo.ts b/web/src/hooks/UserInfo.ts index 88eb258a8..56c433b9d 100644 --- a/web/src/hooks/UserInfo.ts +++ b/web/src/hooks/UserInfo.ts @@ -1,6 +1,6 @@ import { useRemoteCall } from "@hooks/RemoteCall"; -import { getUserInfo } from "@services/UserInfo"; +import { postUserInfo } from "@services/UserInfo"; -export function useUserInfo() { - return useRemoteCall(getUserInfo, []); +export function useUserInfoPOST() { + return useRemoteCall(postUserInfo, []); } diff --git a/web/src/services/UserInfo.ts b/web/src/services/UserInfo.ts index 08118c9ae..0688fe28b 100644 --- a/web/src/services/UserInfo.ts +++ b/web/src/services/UserInfo.ts @@ -1,7 +1,7 @@ import { SecondFactorMethod } from "@models/Methods"; import { UserInfo } from "@models/UserInfo"; import { UserInfo2FAMethodPath, UserInfoPath } from "@services/Api"; -import { Get, PostWithOptionalResponse } from "@services/Client"; +import { Post, PostWithOptionalResponse } from "@services/Client"; export type Method2FA = "webauthn" | "totp" | "mobile_push"; @@ -39,8 +39,8 @@ export function toString(method: SecondFactorMethod): Method2FA { } } -export async function getUserInfo(): Promise { - const res = await Get(UserInfoPath); +export async function postUserInfo(): Promise { + const res = await Post(UserInfoPath); return { ...res, method: toEnum(res.method) }; } diff --git a/web/src/views/LoginPortal/LoginPortal.tsx b/web/src/views/LoginPortal/LoginPortal.tsx index 3aa224465..ea507c74a 100644 --- a/web/src/views/LoginPortal/LoginPortal.tsx +++ b/web/src/views/LoginPortal/LoginPortal.tsx @@ -16,7 +16,7 @@ import { useRedirectionURL } from "@hooks/RedirectionURL"; import { useRedirector } from "@hooks/Redirector"; import { useRequestMethod } from "@hooks/RequestMethod"; import { useAutheliaState } from "@hooks/State"; -import { useUserInfo } from "@hooks/UserInfo"; +import { useUserInfoPOST } from "@hooks/UserInfo"; import { SecondFactorMethod } from "@models/Methods"; import { checkSafeRedirection } from "@services/SafeRedirection"; import { AuthenticationLevel } from "@services/State"; @@ -44,7 +44,7 @@ const LoginPortal = function (props: Props) { const redirector = useRedirector(); const [state, fetchState, , fetchStateError] = useAutheliaState(); - const [userInfo, fetchUserInfo, , fetchUserInfoError] = useUserInfo(); + const [userInfo, fetchUserInfo, , fetchUserInfoError] = useUserInfoPOST(); const [configuration, fetchConfiguration, , fetchConfigurationError] = useConfiguration(); const redirect = useCallback((url: string) => navigate(url), [navigate]); diff --git a/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx b/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx index 9740a89ef..06387ef7e 100644 --- a/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx +++ b/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx @@ -85,22 +85,26 @@ const SecondFactorForm = function (props: Props) { return ( - setMethodSelectionOpen(false)} - onClick={handleMethodSelected} - /> + {props.configuration.available_methods.size > 1 ? ( + setMethodSelectionOpen(false)} + onClick={handleMethodSelected} + /> + ) : null} - {" | "} - + {props.configuration.available_methods.size > 1 ? " | " : null} + {props.configuration.available_methods.size > 1 ? ( + + ) : null}