diff --git a/cmd/authelia-scripts/authelia-scripts b/cmd/authelia-scripts/authelia-scripts index d9f522641..2bf8bb3ad 100755 --- a/cmd/authelia-scripts/authelia-scripts +++ b/cmd/authelia-scripts/authelia-scripts @@ -1,3 +1,3 @@ #!/bin/bash -go run cmd/authelia-scripts/*.go $* \ No newline at end of file +go run -tags migration cmd/authelia-scripts/*.go $* \ No newline at end of file diff --git a/cmd/authelia-scripts/main.go b/cmd/authelia-scripts/main.go index 48eeefc2d..938a0941a 100755 --- a/cmd/authelia-scripts/main.go +++ b/cmd/authelia-scripts/main.go @@ -80,6 +80,14 @@ var Commands = []AutheliaCommandDefinition{ Short: "Run unit tests", Func: RunUnitTest, }, + AutheliaCommandDefinition{ + Name: "migrate", + Short: "Migrate data from v3 to v4", + SubCommands: CobraCommands{ + MigrateLocalCmd, + MigrateMongoCmd, + }, + }, } func levelStringToLevel(level string) log.Level { diff --git a/cmd/authelia-scripts/migration.go b/cmd/authelia-scripts/migration.go new file mode 100644 index 000000000..273241afe --- /dev/null +++ b/cmd/authelia-scripts/migration.go @@ -0,0 +1,67 @@ +package main + +import ( + "encoding/base64" + "strings" + + "github.com/clems4ever/authelia/configuration" + "github.com/clems4ever/authelia/storage" +) + +// TOTPSecretsV3 one entry of TOTP secrets in v3 +type TOTPSecretsV3 struct { + UserID string `json:"userId"` + Secret struct { + Base32 string `json:"base32"` + } `json:"secret"` +} + +// U2FDeviceHandleV3 one entry of U2F device handle in v3 +type U2FDeviceHandleV3 struct { + UserID string `json:"userId"` + Registration struct { + KeyHandle string `json:"keyHandle"` + PublicKey string `json:"publicKey"` + } `json:"registration"` +} + +// PreferencesV3 one entry of preferences in v3 +type PreferencesV3 struct { + UserID string `json:"userId"` + Method string `json:"method"` +} + +// AuthenticationTraceV3 one authentication trace in v3 +type AuthenticationTraceV3 struct { + UserID string `json:"userId"` + Successful bool `json:"isAuthenticationSuccessful"` + Date struct { + Date int64 `json:"$$date"` + } `json:"date"` +} + +func decodeWebsafeBase64(s string) ([]byte, error) { + s = strings.ReplaceAll(s, "_", "/") + s = strings.ReplaceAll(s, "-", "+") + + for len(s)%4 != 0 { + s += "=" + } + + return base64.StdEncoding.DecodeString(s) +} + +func createDBProvider(configurationPath string) storage.Provider { + config, _ := configuration.Read(configurationPath) + + var dbProvider storage.Provider + if config.Storage.Local != nil { + dbProvider = storage.NewSQLiteProvider(config.Storage.Local.Path) + } else if config.Storage.MySQL != nil { + dbProvider = storage.NewMySQLProvider(*config.Storage.MySQL) + } else if config.Storage.PostgreSQL != nil { + dbProvider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL) + } + + return dbProvider +} diff --git a/cmd/authelia-scripts/migration_local.go b/cmd/authelia-scripts/migration_local.go new file mode 100644 index 000000000..8192cf3a0 --- /dev/null +++ b/cmd/authelia-scripts/migration_local.go @@ -0,0 +1,152 @@ +package main + +import ( + "bufio" + "encoding/json" + "log" + "os" + "path" + "time" + + "github.com/clems4ever/authelia/models" + "github.com/clems4ever/authelia/storage" + "github.com/spf13/cobra" +) + +var configurationPath string +var localDatabasePath string + +// MigrateLocalCmd migration command +var MigrateLocalCmd = &cobra.Command{ + Use: "localdb", + Short: "Migrate data from v3 local database into database configured in v4 configuration file", + Run: migrateLocal, +} + +func init() { + MigrateLocalCmd.PersistentFlags().StringVarP(&localDatabasePath, "db-path", "p", "", "The path to the v3 local database") + MigrateLocalCmd.MarkPersistentFlagRequired("db-path") + + MigrateLocalCmd.PersistentFlags().StringVarP(&configurationPath, "config", "c", "", "The configuration file of Authelia v4") + MigrateLocalCmd.MarkPersistentFlagRequired("config") +} + +// migrateLocal data from v3 to v4 +func migrateLocal(cmd *cobra.Command, args []string) { + dbProvider := createDBProvider(configurationPath) + + migrateLocalTOTPSecret(dbProvider) + migrateLocalU2FSecret(dbProvider) + migrateLocalPreferences(dbProvider) + migrateLocalAuthenticationTraces(dbProvider) + // We don't need to migrate identity tokens + + log.Println("Migration done!") +} + +func migrateLocalTOTPSecret(dbProvider storage.Provider) { + file, err := os.Open(path.Join(localDatabasePath, "totp_secrets")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + for scanner.Scan() { + data := scanner.Text() + + entry := TOTPSecretsV3{} + json.Unmarshal([]byte(data), &entry) + err := dbProvider.SaveTOTPSecret(entry.UserID, entry.Secret.Base32) + + if err != nil { + log.Fatal(err) + } + } +} + +func migrateLocalU2FSecret(dbProvider storage.Provider) { + file, err := os.Open(path.Join(localDatabasePath, "u2f_registrations")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + for scanner.Scan() { + data := scanner.Text() + + entry := U2FDeviceHandleV3{} + json.Unmarshal([]byte(data), &entry) + + kH, err := decodeWebsafeBase64(entry.Registration.KeyHandle) + + if err != nil { + log.Fatal(err) + } + + pK, err := decodeWebsafeBase64(entry.Registration.PublicKey) + + if err != nil { + log.Fatal(err) + } + + err = dbProvider.SaveU2FDeviceHandle(entry.UserID, kH, pK) + + if err != nil { + log.Fatal(err) + } + } +} + +func migrateLocalPreferences(dbProvider storage.Provider) { + file, err := os.Open(path.Join(localDatabasePath, "prefered_2fa_method")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + for scanner.Scan() { + data := scanner.Text() + + entry := PreferencesV3{} + json.Unmarshal([]byte(data), &entry) + err := dbProvider.SavePrefered2FAMethod(entry.UserID, entry.Method) + + if err != nil { + log.Fatal(err) + } + } +} + +func migrateLocalAuthenticationTraces(dbProvider storage.Provider) { + file, err := os.Open(path.Join(localDatabasePath, "authentication_traces")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + for scanner.Scan() { + data := scanner.Text() + + entry := AuthenticationTraceV3{} + json.Unmarshal([]byte(data), &entry) + + attempt := models.AuthenticationAttempt{ + Username: entry.UserID, + Successful: entry.Successful, + Time: time.Unix(entry.Date.Date/1000.0, 0), + } + err := dbProvider.AppendAuthenticationLog(attempt) + + if err != nil { + log.Fatal(err) + } + } +} diff --git a/cmd/authelia-scripts/migration_mongo.go b/cmd/authelia-scripts/migration_mongo.go new file mode 100644 index 000000000..daadc59f0 --- /dev/null +++ b/cmd/authelia-scripts/migration_mongo.go @@ -0,0 +1,184 @@ +package main + +import ( + "context" + "log" + "time" + + "github.com/clems4ever/authelia/models" + "github.com/clems4ever/authelia/storage" + "github.com/spf13/cobra" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +var mongoURL string +var mongoDatabase string + +// MigrateMongoCmd migration command +var MigrateMongoCmd = &cobra.Command{ + Use: "mongo", + Short: "Migrate data from v3 mongo database into database configured in v4 configuration file", + Run: migrateMongo, +} + +func init() { + MigrateMongoCmd.PersistentFlags().StringVar(&mongoURL, "url", "", "The address to the mongo server") + MigrateMongoCmd.MarkPersistentFlagRequired("url") + + MigrateMongoCmd.PersistentFlags().StringVar(&mongoDatabase, "database", "", "The mongo database") + MigrateMongoCmd.MarkPersistentFlagRequired("database") + + MigrateMongoCmd.PersistentFlags().StringVarP(&configurationPath, "config", "c", "", "The configuration file of Authelia v4") + MigrateMongoCmd.MarkPersistentFlagRequired("config") +} + +func migrateMongo(cmd *cobra.Command, args []string) { + dbProvider := createDBProvider(configurationPath) + client, err := mongo.NewClient(options.Client().ApplyURI(mongoURL)) + + if err != nil { + log.Fatal(err) + } + + err = client.Connect(context.Background()) + + if err != nil { + log.Fatal(err) + } + + db := client.Database(mongoDatabase) + + migrateMongoU2FDevices(db, dbProvider) + migrateMongoTOTPDevices(db, dbProvider) + migrateMongoPreferences(db, dbProvider) + + log.Println("Migration done!") +} + +func migrateMongoU2FDevices(db *mongo.Database, dbProvider storage.Provider) { + u2fCollection := db.Collection("u2f_registrations") + + cur, err := u2fCollection.Find(context.Background(), bson.D{}) + if err != nil { + log.Fatal(err) + } + defer cur.Close(context.Background()) + + for cur.Next(context.Background()) { + var result U2FDeviceHandleV3 + err := cur.Decode(&result) + if err != nil { + log.Fatal(err) + } + + kH, err := decodeWebsafeBase64(result.Registration.KeyHandle) + + if err != nil { + log.Fatal(err) + } + + pK, err := decodeWebsafeBase64(result.Registration.PublicKey) + + if err != nil { + log.Fatal(err) + } + + err = dbProvider.SaveU2FDeviceHandle(result.UserID, kH, pK) + + if err != nil { + log.Fatal(err) + } + } + if err := cur.Err(); err != nil { + log.Fatal(err) + } +} + +func migrateMongoTOTPDevices(db *mongo.Database, dbProvider storage.Provider) { + u2fCollection := db.Collection("totp_secrets") + + cur, err := u2fCollection.Find(context.Background(), bson.D{}) + if err != nil { + log.Fatal(err) + } + defer cur.Close(context.Background()) + + for cur.Next(context.Background()) { + var result TOTPSecretsV3 + err := cur.Decode(&result) + if err != nil { + log.Fatal(err) + } + + err = dbProvider.SaveTOTPSecret(result.UserID, result.Secret.Base32) + + if err != nil { + log.Fatal(err) + } + } + if err := cur.Err(); err != nil { + log.Fatal(err) + } +} + +func migrateMongoPreferences(db *mongo.Database, dbProvider storage.Provider) { + u2fCollection := db.Collection("prefered_2fa_method") + + cur, err := u2fCollection.Find(context.Background(), bson.D{}) + if err != nil { + log.Fatal(err) + } + defer cur.Close(context.Background()) + + for cur.Next(context.Background()) { + var result PreferencesV3 + err := cur.Decode(&result) + if err != nil { + log.Fatal(err) + } + + err = dbProvider.SavePrefered2FAMethod(result.UserID, result.Method) + + if err != nil { + log.Fatal(err) + } + } + if err := cur.Err(); err != nil { + log.Fatal(err) + } +} + +func migrateMongoAuthenticationTraces(db *mongo.Database, dbProvider storage.Provider) { + u2fCollection := db.Collection("authentication_traces") + + cur, err := u2fCollection.Find(context.Background(), bson.D{}) + if err != nil { + log.Fatal(err) + } + defer cur.Close(context.Background()) + + for cur.Next(context.Background()) { + var result AuthenticationTraceV3 + err := cur.Decode(&result) + if err != nil { + log.Fatal(err) + } + + attempt := models.AuthenticationAttempt{ + Username: result.UserID, + Successful: result.Successful, + Time: time.Unix(result.Date.Date/1000.0, 0), + } + + err = dbProvider.AppendAuthenticationLog(attempt) + + if err != nil { + log.Fatal(err) + } + } + if err := cur.Err(); err != nil { + log.Fatal(err) + } +} diff --git a/example/compose/authelia/docker-compose.backend.yml b/example/compose/authelia/docker-compose.backend.yml index 15e3f3a58..78c60340c 100644 --- a/example/compose/authelia/docker-compose.backend.yml +++ b/example/compose/authelia/docker-compose.backend.yml @@ -14,7 +14,5 @@ services: environment: - SUITE_PATH=${SUITE_PATH} - ENVIRONMENT=dev - ports: - - 9091:9091 networks: - authelianet diff --git a/example/compose/authelia/docker-compose.frontend.yml b/example/compose/authelia/docker-compose.frontend.yml index c8ea59838..6515e1073 100644 --- a/example/compose/authelia/docker-compose.frontend.yml +++ b/example/compose/authelia/docker-compose.frontend.yml @@ -8,7 +8,5 @@ services: working_dir: /app volumes: - "./client:/app" - ports: - - 3000:3000 networks: - authelianet diff --git a/go.mod b/go.mod index d51f20c0c..53c9f76dd 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,9 @@ require ( github.com/fasthttp/router v0.5.2 github.com/fasthttp/session v1.1.3 github.com/go-sql-driver/mysql v1.4.1 + github.com/go-stack/stack v1.8.0 // indirect github.com/golang/mock v1.3.1 + github.com/golang/snappy v0.0.1 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/pty v1.1.8 // indirect github.com/lib/pq v1.2.0 @@ -28,8 +30,12 @@ require ( github.com/spf13/cobra v0.0.5 github.com/stretchr/testify v1.4.0 github.com/tebeka/selenium v0.9.9 + github.com/tidwall/pretty v1.0.0 // indirect github.com/tstranex/u2f v1.0.0 github.com/valyala/fasthttp v1.6.0 + github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c // indirect + github.com/xdg/stringprep v1.0.0 // indirect + go.mongodb.org/mongo-driver v1.1.3 google.golang.org/appengine v1.6.5 // indirect gopkg.in/ldap.v3 v3.1.0 gopkg.in/yaml.v2 v2.2.4 diff --git a/go.sum b/go.sum index 9559b1e43..2c2cbc142 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,8 @@ github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDA github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -58,6 +60,8 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -153,6 +157,8 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/tebeka/selenium v0.9.9 h1:cNziB+etNgyH/7KlNI7RMC1ua5aH1+5wUlFQyzeMh+w= github.com/tebeka/selenium v0.9.9/go.mod h1:5Fr8+pUvU6B1OiPfkdCKdXZyr5znvVkxuPd0NOdZCQc= +github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU= github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tstranex/u2f v1.0.0 h1:HhJkSzDDlVSVIVt7pDJwCHQj67k7A5EeBgPmeD+pVsQ= @@ -163,7 +169,13 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC github.com/valyala/fasthttp v1.6.0 h1:uWF8lgKmeaIewWVPwi4GRq2P6+R46IgYZdxWtM+GtEY= github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0= +github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +go.mongodb.org/mongo-driver v1.1.3 h1:++7u8r9adKhGR+I79NfEtYrk2ktjenErXM99PSufIoI= +go.mongodb.org/mongo-driver v1.1.3/go.mod h1:u7ryQJ+DOzQmeO7zB6MHyr8jkEQvC8vH7qLUO4lqsUM= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= golang.org/x/crypto v0.0.0-20181112202954-3d3f9f413869/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -233,6 +245,7 @@ golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190624190245-7f2218787638 h1:uIfBkD8gLczr4XDgYpt/qJYds2YJwZRNw4zs7wSnNhk= golang.org/x/tools v0.0.0-20190624190245-7f2218787638/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= diff --git a/handlers/handler_register_u2f_step1.go b/handlers/handler_register_u2f_step1.go index 0dd6c2d31..5cb3fe8c3 100644 --- a/handlers/handler_register_u2f_step1.go +++ b/handlers/handler_register_u2f_step1.go @@ -25,7 +25,7 @@ var SecondFactorU2FIdentityStart = middlewares.IdentityVerificationStart(middlew func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { appID := fmt.Sprintf("%s://%s", ctx.XForwardedProto(), ctx.XForwardedHost()) - ctx.Logger.Debugf("U2F appID is %s", appID) + ctx.Logger.Tracef("U2F appID is %s", appID) var trustedFacets = []string{appID} challenge, err := u2f.NewChallenge(appID, trustedFacets) diff --git a/handlers/handler_register_u2f_step2.go b/handlers/handler_register_u2f_step2.go index 83857ac0b..c666cec91 100644 --- a/handlers/handler_register_u2f_step2.go +++ b/handlers/handler_register_u2f_step2.go @@ -1,6 +1,7 @@ package handlers import ( + "crypto/elliptic" "fmt" "github.com/clems4ever/authelia/middlewares" @@ -32,14 +33,15 @@ func SecondFactorU2FRegister(ctx *middlewares.AutheliaCtx) { return } - deviceHandle, err := registration.MarshalBinary() if err != nil { ctx.Error(fmt.Errorf("Unable to marshal U2F registration data: %v", err), unableToRegisterSecurityKeyMessage) return } ctx.Logger.Debugf("Register U2F device for user %s", userSession.Username) - err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, deviceHandle) + + publicKey := elliptic.Marshal(elliptic.P256(), registration.PubKey.X, registration.PubKey.Y) + err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, registration.KeyHandle, publicKey) if err != nil { ctx.Error(fmt.Errorf("Unable to register U2F device for user %s: %v", userSession.Username, err), unableToRegisterSecurityKeyMessage) diff --git a/handlers/handler_sign_u2f_step1.go b/handlers/handler_sign_u2f_step1.go index b8198d8b2..4d9a47ea7 100644 --- a/handlers/handler_sign_u2f_step1.go +++ b/handlers/handler_sign_u2f_step1.go @@ -1,9 +1,11 @@ package handlers import ( + "crypto/elliptic" "fmt" "github.com/clems4ever/authelia/middlewares" + "github.com/clems4ever/authelia/session" "github.com/clems4ever/authelia/storage" "github.com/tstranex/u2f" ) @@ -21,7 +23,7 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) { return } - registrationBin, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username) + keyHandleBytes, publicKeyBytes, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username) if err != nil { if err == storage.ErrNoU2FDeviceHandle { @@ -32,20 +34,18 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) { return } - if len(registrationBin) == 0 { - ctx.Error(fmt.Errorf("Wrong format of device handler for user %s", userSession.Username), mfaValidationFailedMessage) - return - } - var registration u2f.Registration - err = registration.UnmarshalBinary(registrationBin) - if err != nil { - ctx.Error(fmt.Errorf("Unable to unmarshal U2F device handle: %s", err), mfaValidationFailedMessage) - return - } + registration.KeyHandle = keyHandleBytes + x, y := elliptic.Unmarshal(elliptic.P256(), publicKeyBytes) + registration.PubKey.Curve = elliptic.P256() + registration.PubKey.X = x + registration.PubKey.Y = y // Save the challenge and registration for use in next request - userSession.U2FRegistration = ®istration + userSession.U2FRegistration = &session.U2FRegistration{ + KeyHandle: keyHandleBytes, + PublicKey: publicKeyBytes, + } userSession.U2FChallenge = challenge err = ctx.SaveSession(userSession) diff --git a/handlers/handler_sign_u2f_step2.go b/handlers/handler_sign_u2f_step2.go index 4f54a4b3b..e0cee6077 100644 --- a/handlers/handler_sign_u2f_step2.go +++ b/handlers/handler_sign_u2f_step2.go @@ -1,11 +1,13 @@ package handlers import ( + "crypto/elliptic" "fmt" "net/url" "github.com/clems4ever/authelia/authentication" "github.com/clems4ever/authelia/middlewares" + "github.com/tstranex/u2f" ) // SecondFactorU2FSignPost handler for completing a signing request. @@ -29,8 +31,15 @@ func SecondFactorU2FSignPost(ctx *middlewares.AutheliaCtx) { return } + var registration u2f.Registration + registration.KeyHandle = userSession.U2FRegistration.KeyHandle + x, y := elliptic.Unmarshal(elliptic.P256(), userSession.U2FRegistration.PublicKey) + registration.PubKey.Curve = elliptic.P256() + registration.PubKey.X = x + registration.PubKey.Y = y + // TODO(c.michaud): store the counter to help detecting cloned U2F keys. - _, err = userSession.U2FRegistration.Authenticate( + _, err = registration.Authenticate( requestBody.SignResponse, *userSession.U2FChallenge, 0) if err != nil { diff --git a/middlewares/authelia_context.go b/middlewares/authelia_context.go index 806a5449b..4bd3b736a 100644 --- a/middlewares/authelia_context.go +++ b/middlewares/authelia_context.go @@ -24,7 +24,7 @@ func NewAutheliaCtx(ctx *fasthttp.RequestCtx, configuration schema.Configuration userSession, err := providers.SessionProvider.GetSession(ctx) if err != nil { - return nil, fmt.Errorf("Unable to retrieve user session: %s", err.Error()) + return autheliaCtx, fmt.Errorf("Unable to retrieve user session: %s", err.Error()) } autheliaCtx.userSession = userSession @@ -47,7 +47,12 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers) // Error reply with an error and display the stack trace in the logs. func (c *AutheliaCtx) Error(err error, message string) { - b, _ := json.Marshal(ErrorResponse{Status: "KO", Message: message}) + b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message}) + + if marshalErr != nil { + c.Logger.Error(marshalErr) + } + c.SetContentType("application/json") c.SetBody(b) c.Logger.Error(err) @@ -55,7 +60,12 @@ func (c *AutheliaCtx) Error(err error, message string) { // ReplyError reply with an error but does not display any stack trace in the logs func (c *AutheliaCtx) ReplyError(err error, message string) { - b, _ := json.Marshal(ErrorResponse{Status: "KO", Message: message}) + b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message}) + + if marshalErr != nil { + c.Logger.Error(marshalErr) + } + c.SetContentType("application/json") c.SetBody(b) c.Logger.Debug(err) diff --git a/middlewares/identity_verification.go b/middlewares/identity_verification.go index 8bbe7ccc1..df659a193 100644 --- a/middlewares/identity_verification.go +++ b/middlewares/identity_verification.go @@ -65,7 +65,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle return } - ctx.Logger.Debugf("Sending an email to user %s (%s) to confirm identity for registering a TOTP device.", + ctx.Logger.Debugf("Sending an email to user %s (%s) to confirm identity for registering a device.", identity.Username, identity.Email) err = ctx.Providers.Notifier.Send(identity.Email, args.MailSubject, buf.String()) diff --git a/mocks/mock_authelia_ctx.go b/mocks/mock_authelia_ctx.go index bd0746477..cdf1868b6 100644 --- a/mocks/mock_authelia_ctx.go +++ b/mocks/mock_authelia_ctx.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/clems4ever/authelia/regulation" + "github.com/clems4ever/authelia/storage" "github.com/stretchr/testify/assert" "github.com/clems4ever/authelia/authorization" @@ -27,7 +28,7 @@ type MockAutheliaCtx struct { // Providers UserProviderMock *MockUserProvider - StorageProviderMock *MockStorageProvider + StorageProviderMock *storage.MockProvider NotifierMock *MockNotifier UserSession *session.UserSession @@ -62,7 +63,7 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx { mockAuthelia.UserProviderMock = NewMockUserProvider(mockAuthelia.Ctrl) providers.UserProvider = mockAuthelia.UserProviderMock - mockAuthelia.StorageProviderMock = NewMockStorageProvider(mockAuthelia.Ctrl) + mockAuthelia.StorageProviderMock = storage.NewMockProvider(mockAuthelia.Ctrl) providers.StorageProvider = mockAuthelia.StorageProviderMock mockAuthelia.NotifierMock = NewMockNotifier(mockAuthelia.Ctrl) diff --git a/mocks/mock_storage_provider.go b/mocks/mock_storage_provider.go deleted file mode 100644 index 3c923630f..000000000 --- a/mocks/mock_storage_provider.go +++ /dev/null @@ -1,195 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: storage/provider.go - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - time "time" - - models "github.com/clems4ever/authelia/models" - gomock "github.com/golang/mock/gomock" -) - -// MockStorageProvider is a mock of Provider interface -type MockStorageProvider struct { - ctrl *gomock.Controller - recorder *MockStorageProviderMockRecorder -} - -// MockStorageProviderMockRecorder is the mock recorder for MockStorageProvider -type MockStorageProviderMockRecorder struct { - mock *MockStorageProvider -} - -// NewMockStorageProvider creates a new mock instance -func NewMockStorageProvider(ctrl *gomock.Controller) *MockStorageProvider { - mock := &MockStorageProvider{ctrl: ctrl} - mock.recorder = &MockStorageProviderMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockStorageProvider) EXPECT() *MockStorageProviderMockRecorder { - return m.recorder -} - -// LoadPrefered2FAMethod mocks base method -func (m *MockStorageProvider) LoadPrefered2FAMethod(username string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadPrefered2FAMethod", username) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadPrefered2FAMethod indicates an expected call of LoadPrefered2FAMethod -func (mr *MockStorageProviderMockRecorder) LoadPrefered2FAMethod(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPrefered2FAMethod", reflect.TypeOf((*MockStorageProvider)(nil).LoadPrefered2FAMethod), username) -} - -// SavePrefered2FAMethod mocks base method -func (m *MockStorageProvider) SavePrefered2FAMethod(username, method string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SavePrefered2FAMethod", username, method) - ret0, _ := ret[0].(error) - return ret0 -} - -// SavePrefered2FAMethod indicates an expected call of SavePrefered2FAMethod -func (mr *MockStorageProviderMockRecorder) SavePrefered2FAMethod(username, method interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePrefered2FAMethod", reflect.TypeOf((*MockStorageProvider)(nil).SavePrefered2FAMethod), username, method) -} - -// FindIdentityVerificationToken mocks base method -func (m *MockStorageProvider) FindIdentityVerificationToken(token string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FindIdentityVerificationToken", token) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken -func (mr *MockStorageProviderMockRecorder) FindIdentityVerificationToken(token interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockStorageProvider)(nil).FindIdentityVerificationToken), token) -} - -// SaveIdentityVerificationToken mocks base method -func (m *MockStorageProvider) SaveIdentityVerificationToken(token string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", token) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken -func (mr *MockStorageProviderMockRecorder) SaveIdentityVerificationToken(token interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockStorageProvider)(nil).SaveIdentityVerificationToken), token) -} - -// RemoveIdentityVerificationToken mocks base method -func (m *MockStorageProvider) RemoveIdentityVerificationToken(token string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", token) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken -func (mr *MockStorageProviderMockRecorder) RemoveIdentityVerificationToken(token interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockStorageProvider)(nil).RemoveIdentityVerificationToken), token) -} - -// SaveTOTPSecret mocks base method -func (m *MockStorageProvider) SaveTOTPSecret(username, secret string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveTOTPSecret", username, secret) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveTOTPSecret indicates an expected call of SaveTOTPSecret -func (mr *MockStorageProviderMockRecorder) SaveTOTPSecret(username, secret interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockStorageProvider)(nil).SaveTOTPSecret), username, secret) -} - -// LoadTOTPSecret mocks base method -func (m *MockStorageProvider) LoadTOTPSecret(username string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadTOTPSecret", username) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadTOTPSecret indicates an expected call of LoadTOTPSecret -func (mr *MockStorageProviderMockRecorder) LoadTOTPSecret(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockStorageProvider)(nil).LoadTOTPSecret), username) -} - -// SaveU2FDeviceHandle mocks base method -func (m *MockStorageProvider) SaveU2FDeviceHandle(username string, device []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", username, device) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle -func (mr *MockStorageProviderMockRecorder) SaveU2FDeviceHandle(username, device interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockStorageProvider)(nil).SaveU2FDeviceHandle), username, device) -} - -// LoadU2FDeviceHandle mocks base method -func (m *MockStorageProvider) LoadU2FDeviceHandle(username string) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", username) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle -func (mr *MockStorageProviderMockRecorder) LoadU2FDeviceHandle(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockStorageProvider)(nil).LoadU2FDeviceHandle), username) -} - -// AppendAuthenticationLog mocks base method -func (m *MockStorageProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendAuthenticationLog", attempt) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog -func (mr *MockStorageProviderMockRecorder) AppendAuthenticationLog(attempt interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorageProvider)(nil).AppendAuthenticationLog), attempt) -} - -// LoadLatestAuthenticationLogs mocks base method -func (m *MockStorageProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", username, fromDate) - ret0, _ := ret[0].([]models.AuthenticationAttempt) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs -func (mr *MockStorageProviderMockRecorder) LoadLatestAuthenticationLogs(username, fromDate interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockStorageProvider)(nil).LoadLatestAuthenticationLogs), username, fromDate) -} diff --git a/regulation/const.go b/regulation/const.go index bfcd57f24..bca58993e 100644 --- a/regulation/const.go +++ b/regulation/const.go @@ -2,4 +2,5 @@ package regulation import "fmt" +// ErrUserIsBanned user is banned error message var ErrUserIsBanned = fmt.Errorf("User is banned") diff --git a/regulation/regulator_test.go b/regulation/regulator_test.go index 637816326..ed361b0b4 100644 --- a/regulation/regulator_test.go +++ b/regulation/regulator_test.go @@ -5,9 +5,9 @@ import ( "time" "github.com/clems4ever/authelia/configuration/schema" - "github.com/clems4ever/authelia/mocks" "github.com/clems4ever/authelia/models" "github.com/clems4ever/authelia/regulation" + "github.com/clems4ever/authelia/storage" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -17,14 +17,14 @@ type RegulatorSuite struct { suite.Suite ctrl *gomock.Controller - storageMock *mocks.MockStorageProvider + storageMock *storage.MockProvider configuration schema.RegulationConfiguration now time.Time } func (s *RegulatorSuite) SetupTest() { s.ctrl = gomock.NewController(s.T()) - s.storageMock = mocks.NewMockStorageProvider(s.ctrl) + s.storageMock = storage.NewMockProvider(s.ctrl) s.configuration = schema.RegulationConfiguration{ MaxRetries: 3, diff --git a/session/types.go b/session/types.go index 33ab8e65d..a1a1f3ef2 100644 --- a/session/types.go +++ b/session/types.go @@ -13,6 +13,12 @@ type ProviderConfig struct { providerConfig session.ProviderConfig } +// U2FRegistration is a serializable version of a U2F registration +type U2FRegistration struct { + KeyHandle []byte + PublicKey []byte +} + // UserSession is the structure representing the session of a user. type UserSession struct { Username string @@ -29,7 +35,7 @@ type UserSession struct { U2FChallenge *u2f.Challenge // The registration representing a U2F device in DB set after identity verification. // This is used in second phase of a U2F authentication. - U2FRegistration *u2f.Registration + U2FRegistration *U2FRegistration // This boolean is set to true after identity verification and checked // while doing the query actually updating the password. diff --git a/storage/constants.go b/storage/constants.go index 6ce860dfd..80de0858c 100644 --- a/storage/constants.go +++ b/storage/constants.go @@ -1,6 +1,6 @@ package storage -var preferencesTableName = "PreferencesTableName" +var preferencesTableName = "Preferences" var identityVerificationTokensTableName = "IdentityVerificationTokens" var totpSecretsTableName = "TOTPSecrets" var u2fDeviceHandlesTableName = "U2FDeviceHandles" diff --git a/storage/mysql_provider.go b/storage/mysql_provider.go index 8c5866602..bff40d48d 100644 --- a/storage/mysql_provider.go +++ b/storage/mysql_provider.go @@ -53,8 +53,8 @@ func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProv sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName), sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName), - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT deviceHandle FROM %s WHERE username=?", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, deviceHandle) VALUES (?, ?)", u2fDeviceHandlesTableName), + sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName), + sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName), sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName), sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName), diff --git a/storage/postgres_provider.go b/storage/postgres_provider.go index 3f2aa825d..e69bb135c 100644 --- a/storage/postgres_provider.go +++ b/storage/postgres_provider.go @@ -61,8 +61,8 @@ func NewPostgreSQLProvider(configuration schema.PostgreSQLStorageConfiguration) sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=$1", totpSecretsTableName), sqlUpsertTOTPSecret: fmt.Sprintf("INSERT INTO %s (username, secret) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET secret=$2", totpSecretsTableName), - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT deviceHandle FROM %s WHERE username=$1", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, deviceHandle) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET deviceHandle=$2", u2fDeviceHandlesTableName), + sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=$1", u2fDeviceHandlesTableName), + sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, keyHandle, publicKey) VALUES ($1, $2, $3) ON CONFLICT (username) DO UPDATE SET keyHandle=$2, publicKey=$3", u2fDeviceHandlesTableName), sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES ($1, $2, $3)", authenticationLogsTableName), sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>$1 AND username=$2 ORDER BY time DESC", authenticationLogsTableName), diff --git a/storage/provider.go b/storage/provider.go index 99ce3f949..9a79d589f 100644 --- a/storage/provider.go +++ b/storage/provider.go @@ -19,8 +19,8 @@ type Provider interface { SaveTOTPSecret(username string, secret string) error LoadTOTPSecret(username string) (string, error) - SaveU2FDeviceHandle(username string, device []byte) error - LoadU2FDeviceHandle(username string) ([]byte, error) + SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error + LoadU2FDeviceHandle(username string) ([]byte, []byte, error) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) diff --git a/storage/provider_mock.go b/storage/provider_mock.go new file mode 100644 index 000000000..984f5eaff --- /dev/null +++ b/storage/provider_mock.go @@ -0,0 +1,195 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/clems4ever/authelia/storage (interfaces: Provider) + +// Package storage is a generated GoMock package. +package storage + +import ( + models "github.com/clems4ever/authelia/models" + gomock "github.com/golang/mock/gomock" + reflect "reflect" + time "time" +) + +// MockProvider is a mock of Provider interface +type MockProvider struct { + ctrl *gomock.Controller + recorder *MockProviderMockRecorder +} + +// MockProviderMockRecorder is the mock recorder for MockProvider +type MockProviderMockRecorder struct { + mock *MockProvider +} + +// NewMockProvider creates a new mock instance +func NewMockProvider(ctrl *gomock.Controller) *MockProvider { + mock := &MockProvider{ctrl: ctrl} + mock.recorder = &MockProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockProvider) EXPECT() *MockProviderMockRecorder { + return m.recorder +} + +// AppendAuthenticationLog mocks base method +func (m *MockProvider) AppendAuthenticationLog(arg0 models.AuthenticationAttempt) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AppendAuthenticationLog", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog +func (mr *MockProviderMockRecorder) AppendAuthenticationLog(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), arg0) +} + +// FindIdentityVerificationToken mocks base method +func (m *MockProvider) FindIdentityVerificationToken(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindIdentityVerificationToken", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken +func (mr *MockProviderMockRecorder) FindIdentityVerificationToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerificationToken), arg0) +} + +// LoadLatestAuthenticationLogs mocks base method +func (m *MockProvider) LoadLatestAuthenticationLogs(arg0 string, arg1 time.Time) ([]models.AuthenticationAttempt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", arg0, arg1) + ret0, _ := ret[0].([]models.AuthenticationAttempt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs +func (mr *MockProviderMockRecorder) LoadLatestAuthenticationLogs(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadLatestAuthenticationLogs), arg0, arg1) +} + +// LoadPrefered2FAMethod mocks base method +func (m *MockProvider) LoadPrefered2FAMethod(arg0 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadPrefered2FAMethod", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadPrefered2FAMethod indicates an expected call of LoadPrefered2FAMethod +func (mr *MockProviderMockRecorder) LoadPrefered2FAMethod(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPrefered2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPrefered2FAMethod), arg0) +} + +// LoadTOTPSecret mocks base method +func (m *MockProvider) LoadTOTPSecret(arg0 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadTOTPSecret", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadTOTPSecret indicates an expected call of LoadTOTPSecret +func (mr *MockProviderMockRecorder) LoadTOTPSecret(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockProvider)(nil).LoadTOTPSecret), arg0) +} + +// LoadU2FDeviceHandle mocks base method +func (m *MockProvider) LoadU2FDeviceHandle(arg0 string) ([]byte, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle +func (mr *MockProviderMockRecorder) LoadU2FDeviceHandle(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).LoadU2FDeviceHandle), arg0) +} + +// RemoveIdentityVerificationToken mocks base method +func (m *MockProvider) RemoveIdentityVerificationToken(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken +func (mr *MockProviderMockRecorder) RemoveIdentityVerificationToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerificationToken), arg0) +} + +// SaveIdentityVerificationToken mocks base method +func (m *MockProvider) SaveIdentityVerificationToken(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken +func (mr *MockProviderMockRecorder) SaveIdentityVerificationToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerificationToken), arg0) +} + +// SavePrefered2FAMethod mocks base method +func (m *MockProvider) SavePrefered2FAMethod(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SavePrefered2FAMethod", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SavePrefered2FAMethod indicates an expected call of SavePrefered2FAMethod +func (mr *MockProviderMockRecorder) SavePrefered2FAMethod(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePrefered2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePrefered2FAMethod), arg0, arg1) +} + +// SaveTOTPSecret mocks base method +func (m *MockProvider) SaveTOTPSecret(arg0, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveTOTPSecret", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveTOTPSecret indicates an expected call of SaveTOTPSecret +func (mr *MockProviderMockRecorder) SaveTOTPSecret(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockProvider)(nil).SaveTOTPSecret), arg0, arg1) +} + +// SaveU2FDeviceHandle mocks base method +func (m *MockProvider) SaveU2FDeviceHandle(arg0 string, arg1, arg2 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle +func (mr *MockProviderMockRecorder) SaveU2FDeviceHandle(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).SaveU2FDeviceHandle), arg0, arg1, arg2) +} diff --git a/storage/sql_provider.go b/storage/sql_provider.go index 4a023ea01..ebb238c42 100644 --- a/storage/sql_provider.go +++ b/storage/sql_provider.go @@ -48,7 +48,8 @@ func (p *SQLProvider) initialize(db *sql.DB) error { return err } - _, err = db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, deviceHandle TEXT)", u2fDeviceHandlesTableName)) + // keyHandle and publicKey are stored in base64 format + _, err = db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, keyHandle TEXT, publicKey TEXT)", u2fDeviceHandlesTableName)) if err != nil { return err } @@ -135,22 +136,37 @@ func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) { } // SaveU2FDeviceHandle save a registered U2F device registration blob. -func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte) error { - _, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle, username, base64.StdEncoding.EncodeToString(keyHandle)) +func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error { + _, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle, + username, + base64.StdEncoding.EncodeToString(keyHandle), + base64.StdEncoding.EncodeToString(publicKey)) return err } // LoadU2FDeviceHandle load a U2F device registration blob for a given username. -func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, error) { - var deviceHandle string - if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&deviceHandle); err != nil { +func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) { + var keyHandleBase64, publicKeyBase64 string + if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil { if err == sql.ErrNoRows { - return nil, ErrNoU2FDeviceHandle + return nil, nil, ErrNoU2FDeviceHandle } - return nil, err + return nil, nil, err } - return base64.StdEncoding.DecodeString(deviceHandle) + keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64) + + if err != nil { + return nil, nil, err + } + + publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64) + + if err != nil { + return nil, nil, err + } + + return keyHandle, publicKey, nil } // AppendAuthenticationLog append a mark to the authentication log. diff --git a/storage/sqlite_provider.go b/storage/sqlite_provider.go index 342f851a6..4ef56471d 100644 --- a/storage/sqlite_provider.go +++ b/storage/sqlite_provider.go @@ -32,8 +32,8 @@ func NewSQLiteProvider(path string) *SQLiteProvider { sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName), sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName), - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT deviceHandle FROM %s WHERE username=?", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, deviceHandle) VALUES (?, ?)", u2fDeviceHandlesTableName), + sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName), + sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName), sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName), sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),