Provide commands to migrate database from v3 to v4.

pull/448/head
Clement Michaud 2019-11-17 02:05:46 +01:00 committed by Clément Michaud
parent 6303485fd2
commit a06b69dd45
27 changed files with 715 additions and 244 deletions

View File

@ -1,3 +1,3 @@
#!/bin/bash
go run cmd/authelia-scripts/*.go $*
go run -tags migration cmd/authelia-scripts/*.go $*

View File

@ -80,6 +80,14 @@ var Commands = []AutheliaCommandDefinition{
Short: "Run unit tests",
Func: RunUnitTest,
},
AutheliaCommandDefinition{
Name: "migrate",
Short: "Migrate data from v3 to v4",
SubCommands: CobraCommands{
MigrateLocalCmd,
MigrateMongoCmd,
},
},
}
func levelStringToLevel(level string) log.Level {

View File

@ -0,0 +1,67 @@
package main
import (
"encoding/base64"
"strings"
"github.com/clems4ever/authelia/configuration"
"github.com/clems4ever/authelia/storage"
)
// TOTPSecretsV3 one entry of TOTP secrets in v3
type TOTPSecretsV3 struct {
UserID string `json:"userId"`
Secret struct {
Base32 string `json:"base32"`
} `json:"secret"`
}
// U2FDeviceHandleV3 one entry of U2F device handle in v3
type U2FDeviceHandleV3 struct {
UserID string `json:"userId"`
Registration struct {
KeyHandle string `json:"keyHandle"`
PublicKey string `json:"publicKey"`
} `json:"registration"`
}
// PreferencesV3 one entry of preferences in v3
type PreferencesV3 struct {
UserID string `json:"userId"`
Method string `json:"method"`
}
// AuthenticationTraceV3 one authentication trace in v3
type AuthenticationTraceV3 struct {
UserID string `json:"userId"`
Successful bool `json:"isAuthenticationSuccessful"`
Date struct {
Date int64 `json:"$$date"`
} `json:"date"`
}
func decodeWebsafeBase64(s string) ([]byte, error) {
s = strings.ReplaceAll(s, "_", "/")
s = strings.ReplaceAll(s, "-", "+")
for len(s)%4 != 0 {
s += "="
}
return base64.StdEncoding.DecodeString(s)
}
func createDBProvider(configurationPath string) storage.Provider {
config, _ := configuration.Read(configurationPath)
var dbProvider storage.Provider
if config.Storage.Local != nil {
dbProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
} else if config.Storage.MySQL != nil {
dbProvider = storage.NewMySQLProvider(*config.Storage.MySQL)
} else if config.Storage.PostgreSQL != nil {
dbProvider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL)
}
return dbProvider
}

View File

@ -0,0 +1,152 @@
package main
import (
"bufio"
"encoding/json"
"log"
"os"
"path"
"time"
"github.com/clems4ever/authelia/models"
"github.com/clems4ever/authelia/storage"
"github.com/spf13/cobra"
)
var configurationPath string
var localDatabasePath string
// MigrateLocalCmd migration command
var MigrateLocalCmd = &cobra.Command{
Use: "localdb",
Short: "Migrate data from v3 local database into database configured in v4 configuration file",
Run: migrateLocal,
}
func init() {
MigrateLocalCmd.PersistentFlags().StringVarP(&localDatabasePath, "db-path", "p", "", "The path to the v3 local database")
MigrateLocalCmd.MarkPersistentFlagRequired("db-path")
MigrateLocalCmd.PersistentFlags().StringVarP(&configurationPath, "config", "c", "", "The configuration file of Authelia v4")
MigrateLocalCmd.MarkPersistentFlagRequired("config")
}
// migrateLocal data from v3 to v4
func migrateLocal(cmd *cobra.Command, args []string) {
dbProvider := createDBProvider(configurationPath)
migrateLocalTOTPSecret(dbProvider)
migrateLocalU2FSecret(dbProvider)
migrateLocalPreferences(dbProvider)
migrateLocalAuthenticationTraces(dbProvider)
// We don't need to migrate identity tokens
log.Println("Migration done!")
}
func migrateLocalTOTPSecret(dbProvider storage.Provider) {
file, err := os.Open(path.Join(localDatabasePath, "totp_secrets"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
data := scanner.Text()
entry := TOTPSecretsV3{}
json.Unmarshal([]byte(data), &entry)
err := dbProvider.SaveTOTPSecret(entry.UserID, entry.Secret.Base32)
if err != nil {
log.Fatal(err)
}
}
}
func migrateLocalU2FSecret(dbProvider storage.Provider) {
file, err := os.Open(path.Join(localDatabasePath, "u2f_registrations"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
data := scanner.Text()
entry := U2FDeviceHandleV3{}
json.Unmarshal([]byte(data), &entry)
kH, err := decodeWebsafeBase64(entry.Registration.KeyHandle)
if err != nil {
log.Fatal(err)
}
pK, err := decodeWebsafeBase64(entry.Registration.PublicKey)
if err != nil {
log.Fatal(err)
}
err = dbProvider.SaveU2FDeviceHandle(entry.UserID, kH, pK)
if err != nil {
log.Fatal(err)
}
}
}
func migrateLocalPreferences(dbProvider storage.Provider) {
file, err := os.Open(path.Join(localDatabasePath, "prefered_2fa_method"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
data := scanner.Text()
entry := PreferencesV3{}
json.Unmarshal([]byte(data), &entry)
err := dbProvider.SavePrefered2FAMethod(entry.UserID, entry.Method)
if err != nil {
log.Fatal(err)
}
}
}
func migrateLocalAuthenticationTraces(dbProvider storage.Provider) {
file, err := os.Open(path.Join(localDatabasePath, "authentication_traces"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
data := scanner.Text()
entry := AuthenticationTraceV3{}
json.Unmarshal([]byte(data), &entry)
attempt := models.AuthenticationAttempt{
Username: entry.UserID,
Successful: entry.Successful,
Time: time.Unix(entry.Date.Date/1000.0, 0),
}
err := dbProvider.AppendAuthenticationLog(attempt)
if err != nil {
log.Fatal(err)
}
}
}

View File

@ -0,0 +1,184 @@
package main
import (
"context"
"log"
"time"
"github.com/clems4ever/authelia/models"
"github.com/clems4ever/authelia/storage"
"github.com/spf13/cobra"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
var mongoURL string
var mongoDatabase string
// MigrateMongoCmd migration command
var MigrateMongoCmd = &cobra.Command{
Use: "mongo",
Short: "Migrate data from v3 mongo database into database configured in v4 configuration file",
Run: migrateMongo,
}
func init() {
MigrateMongoCmd.PersistentFlags().StringVar(&mongoURL, "url", "", "The address to the mongo server")
MigrateMongoCmd.MarkPersistentFlagRequired("url")
MigrateMongoCmd.PersistentFlags().StringVar(&mongoDatabase, "database", "", "The mongo database")
MigrateMongoCmd.MarkPersistentFlagRequired("database")
MigrateMongoCmd.PersistentFlags().StringVarP(&configurationPath, "config", "c", "", "The configuration file of Authelia v4")
MigrateMongoCmd.MarkPersistentFlagRequired("config")
}
func migrateMongo(cmd *cobra.Command, args []string) {
dbProvider := createDBProvider(configurationPath)
client, err := mongo.NewClient(options.Client().ApplyURI(mongoURL))
if err != nil {
log.Fatal(err)
}
err = client.Connect(context.Background())
if err != nil {
log.Fatal(err)
}
db := client.Database(mongoDatabase)
migrateMongoU2FDevices(db, dbProvider)
migrateMongoTOTPDevices(db, dbProvider)
migrateMongoPreferences(db, dbProvider)
log.Println("Migration done!")
}
func migrateMongoU2FDevices(db *mongo.Database, dbProvider storage.Provider) {
u2fCollection := db.Collection("u2f_registrations")
cur, err := u2fCollection.Find(context.Background(), bson.D{})
if err != nil {
log.Fatal(err)
}
defer cur.Close(context.Background())
for cur.Next(context.Background()) {
var result U2FDeviceHandleV3
err := cur.Decode(&result)
if err != nil {
log.Fatal(err)
}
kH, err := decodeWebsafeBase64(result.Registration.KeyHandle)
if err != nil {
log.Fatal(err)
}
pK, err := decodeWebsafeBase64(result.Registration.PublicKey)
if err != nil {
log.Fatal(err)
}
err = dbProvider.SaveU2FDeviceHandle(result.UserID, kH, pK)
if err != nil {
log.Fatal(err)
}
}
if err := cur.Err(); err != nil {
log.Fatal(err)
}
}
func migrateMongoTOTPDevices(db *mongo.Database, dbProvider storage.Provider) {
u2fCollection := db.Collection("totp_secrets")
cur, err := u2fCollection.Find(context.Background(), bson.D{})
if err != nil {
log.Fatal(err)
}
defer cur.Close(context.Background())
for cur.Next(context.Background()) {
var result TOTPSecretsV3
err := cur.Decode(&result)
if err != nil {
log.Fatal(err)
}
err = dbProvider.SaveTOTPSecret(result.UserID, result.Secret.Base32)
if err != nil {
log.Fatal(err)
}
}
if err := cur.Err(); err != nil {
log.Fatal(err)
}
}
func migrateMongoPreferences(db *mongo.Database, dbProvider storage.Provider) {
u2fCollection := db.Collection("prefered_2fa_method")
cur, err := u2fCollection.Find(context.Background(), bson.D{})
if err != nil {
log.Fatal(err)
}
defer cur.Close(context.Background())
for cur.Next(context.Background()) {
var result PreferencesV3
err := cur.Decode(&result)
if err != nil {
log.Fatal(err)
}
err = dbProvider.SavePrefered2FAMethod(result.UserID, result.Method)
if err != nil {
log.Fatal(err)
}
}
if err := cur.Err(); err != nil {
log.Fatal(err)
}
}
func migrateMongoAuthenticationTraces(db *mongo.Database, dbProvider storage.Provider) {
u2fCollection := db.Collection("authentication_traces")
cur, err := u2fCollection.Find(context.Background(), bson.D{})
if err != nil {
log.Fatal(err)
}
defer cur.Close(context.Background())
for cur.Next(context.Background()) {
var result AuthenticationTraceV3
err := cur.Decode(&result)
if err != nil {
log.Fatal(err)
}
attempt := models.AuthenticationAttempt{
Username: result.UserID,
Successful: result.Successful,
Time: time.Unix(result.Date.Date/1000.0, 0),
}
err = dbProvider.AppendAuthenticationLog(attempt)
if err != nil {
log.Fatal(err)
}
}
if err := cur.Err(); err != nil {
log.Fatal(err)
}
}

View File

@ -14,7 +14,5 @@ services:
environment:
- SUITE_PATH=${SUITE_PATH}
- ENVIRONMENT=dev
ports:
- 9091:9091
networks:
- authelianet

View File

@ -8,7 +8,5 @@ services:
working_dir: /app
volumes:
- "./client:/app"
ports:
- 3000:3000
networks:
- authelianet

6
go.mod
View File

@ -13,7 +13,9 @@ require (
github.com/fasthttp/router v0.5.2
github.com/fasthttp/session v1.1.3
github.com/go-sql-driver/mysql v1.4.1
github.com/go-stack/stack v1.8.0 // indirect
github.com/golang/mock v1.3.1
github.com/golang/snappy v0.0.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/kr/pty v1.1.8 // indirect
github.com/lib/pq v1.2.0
@ -28,8 +30,12 @@ require (
github.com/spf13/cobra v0.0.5
github.com/stretchr/testify v1.4.0
github.com/tebeka/selenium v0.9.9
github.com/tidwall/pretty v1.0.0 // indirect
github.com/tstranex/u2f v1.0.0
github.com/valyala/fasthttp v1.6.0
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c // indirect
github.com/xdg/stringprep v1.0.0 // indirect
go.mongodb.org/mongo-driver v1.1.3
google.golang.org/appengine v1.6.5 // indirect
gopkg.in/ldap.v3 v3.1.0
gopkg.in/yaml.v2 v2.2.4

13
go.sum
View File

@ -49,6 +49,8 @@ github.com/go-redis/redis v6.15.2+incompatible h1:9SpNVG76gr6InJGxoZ6IuuxaCOQwDA
github.com/go-redis/redis v6.15.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@ -58,6 +60,8 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@ -153,6 +157,8 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/tebeka/selenium v0.9.9 h1:cNziB+etNgyH/7KlNI7RMC1ua5aH1+5wUlFQyzeMh+w=
github.com/tebeka/selenium v0.9.9/go.mod h1:5Fr8+pUvU6B1OiPfkdCKdXZyr5znvVkxuPd0NOdZCQc=
github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tstranex/u2f v1.0.0 h1:HhJkSzDDlVSVIVt7pDJwCHQj67k7A5EeBgPmeD+pVsQ=
@ -163,7 +169,13 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC
github.com/valyala/fasthttp v1.6.0 h1:uWF8lgKmeaIewWVPwi4GRq2P6+R46IgYZdxWtM+GtEY=
github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w=
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk=
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I=
github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0=
github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.mongodb.org/mongo-driver v1.1.3 h1:++7u8r9adKhGR+I79NfEtYrk2ktjenErXM99PSufIoI=
go.mongodb.org/mongo-driver v1.1.3/go.mod h1:u7ryQJ+DOzQmeO7zB6MHyr8jkEQvC8vH7qLUO4lqsUM=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
golang.org/x/crypto v0.0.0-20181112202954-3d3f9f413869/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@ -233,6 +245,7 @@ golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190624190245-7f2218787638 h1:uIfBkD8gLczr4XDgYpt/qJYds2YJwZRNw4zs7wSnNhk=
golang.org/x/tools v0.0.0-20190624190245-7f2218787638/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=

View File

@ -25,7 +25,7 @@ var SecondFactorU2FIdentityStart = middlewares.IdentityVerificationStart(middlew
func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
appID := fmt.Sprintf("%s://%s", ctx.XForwardedProto(), ctx.XForwardedHost())
ctx.Logger.Debugf("U2F appID is %s", appID)
ctx.Logger.Tracef("U2F appID is %s", appID)
var trustedFacets = []string{appID}
challenge, err := u2f.NewChallenge(appID, trustedFacets)

View File

@ -1,6 +1,7 @@
package handlers
import (
"crypto/elliptic"
"fmt"
"github.com/clems4ever/authelia/middlewares"
@ -32,14 +33,15 @@ func SecondFactorU2FRegister(ctx *middlewares.AutheliaCtx) {
return
}
deviceHandle, err := registration.MarshalBinary()
if err != nil {
ctx.Error(fmt.Errorf("Unable to marshal U2F registration data: %v", err), unableToRegisterSecurityKeyMessage)
return
}
ctx.Logger.Debugf("Register U2F device for user %s", userSession.Username)
err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, deviceHandle)
publicKey := elliptic.Marshal(elliptic.P256(), registration.PubKey.X, registration.PubKey.Y)
err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, registration.KeyHandle, publicKey)
if err != nil {
ctx.Error(fmt.Errorf("Unable to register U2F device for user %s: %v", userSession.Username, err), unableToRegisterSecurityKeyMessage)

View File

@ -1,9 +1,11 @@
package handlers
import (
"crypto/elliptic"
"fmt"
"github.com/clems4ever/authelia/middlewares"
"github.com/clems4ever/authelia/session"
"github.com/clems4ever/authelia/storage"
"github.com/tstranex/u2f"
)
@ -21,7 +23,7 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
return
}
registrationBin, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username)
keyHandleBytes, publicKeyBytes, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username)
if err != nil {
if err == storage.ErrNoU2FDeviceHandle {
@ -32,20 +34,18 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
return
}
if len(registrationBin) == 0 {
ctx.Error(fmt.Errorf("Wrong format of device handler for user %s", userSession.Username), mfaValidationFailedMessage)
return
}
var registration u2f.Registration
err = registration.UnmarshalBinary(registrationBin)
if err != nil {
ctx.Error(fmt.Errorf("Unable to unmarshal U2F device handle: %s", err), mfaValidationFailedMessage)
return
}
registration.KeyHandle = keyHandleBytes
x, y := elliptic.Unmarshal(elliptic.P256(), publicKeyBytes)
registration.PubKey.Curve = elliptic.P256()
registration.PubKey.X = x
registration.PubKey.Y = y
// Save the challenge and registration for use in next request
userSession.U2FRegistration = &registration
userSession.U2FRegistration = &session.U2FRegistration{
KeyHandle: keyHandleBytes,
PublicKey: publicKeyBytes,
}
userSession.U2FChallenge = challenge
err = ctx.SaveSession(userSession)

View File

@ -1,11 +1,13 @@
package handlers
import (
"crypto/elliptic"
"fmt"
"net/url"
"github.com/clems4ever/authelia/authentication"
"github.com/clems4ever/authelia/middlewares"
"github.com/tstranex/u2f"
)
// SecondFactorU2FSignPost handler for completing a signing request.
@ -29,8 +31,15 @@ func SecondFactorU2FSignPost(ctx *middlewares.AutheliaCtx) {
return
}
var registration u2f.Registration
registration.KeyHandle = userSession.U2FRegistration.KeyHandle
x, y := elliptic.Unmarshal(elliptic.P256(), userSession.U2FRegistration.PublicKey)
registration.PubKey.Curve = elliptic.P256()
registration.PubKey.X = x
registration.PubKey.Y = y
// TODO(c.michaud): store the counter to help detecting cloned U2F keys.
_, err = userSession.U2FRegistration.Authenticate(
_, err = registration.Authenticate(
requestBody.SignResponse, *userSession.U2FChallenge, 0)
if err != nil {

View File

@ -24,7 +24,7 @@ func NewAutheliaCtx(ctx *fasthttp.RequestCtx, configuration schema.Configuration
userSession, err := providers.SessionProvider.GetSession(ctx)
if err != nil {
return nil, fmt.Errorf("Unable to retrieve user session: %s", err.Error())
return autheliaCtx, fmt.Errorf("Unable to retrieve user session: %s", err.Error())
}
autheliaCtx.userSession = userSession
@ -47,7 +47,12 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers)
// Error reply with an error and display the stack trace in the logs.
func (c *AutheliaCtx) Error(err error, message string) {
b, _ := json.Marshal(ErrorResponse{Status: "KO", Message: message})
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
if marshalErr != nil {
c.Logger.Error(marshalErr)
}
c.SetContentType("application/json")
c.SetBody(b)
c.Logger.Error(err)
@ -55,7 +60,12 @@ func (c *AutheliaCtx) Error(err error, message string) {
// ReplyError reply with an error but does not display any stack trace in the logs
func (c *AutheliaCtx) ReplyError(err error, message string) {
b, _ := json.Marshal(ErrorResponse{Status: "KO", Message: message})
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
if marshalErr != nil {
c.Logger.Error(marshalErr)
}
c.SetContentType("application/json")
c.SetBody(b)
c.Logger.Debug(err)

View File

@ -65,7 +65,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
return
}
ctx.Logger.Debugf("Sending an email to user %s (%s) to confirm identity for registering a TOTP device.",
ctx.Logger.Debugf("Sending an email to user %s (%s) to confirm identity for registering a device.",
identity.Username, identity.Email)
err = ctx.Providers.Notifier.Send(identity.Email, args.MailSubject, buf.String())

View File

@ -6,6 +6,7 @@ import (
"testing"
"github.com/clems4ever/authelia/regulation"
"github.com/clems4ever/authelia/storage"
"github.com/stretchr/testify/assert"
"github.com/clems4ever/authelia/authorization"
@ -27,7 +28,7 @@ type MockAutheliaCtx struct {
// Providers
UserProviderMock *MockUserProvider
StorageProviderMock *MockStorageProvider
StorageProviderMock *storage.MockProvider
NotifierMock *MockNotifier
UserSession *session.UserSession
@ -62,7 +63,7 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx {
mockAuthelia.UserProviderMock = NewMockUserProvider(mockAuthelia.Ctrl)
providers.UserProvider = mockAuthelia.UserProviderMock
mockAuthelia.StorageProviderMock = NewMockStorageProvider(mockAuthelia.Ctrl)
mockAuthelia.StorageProviderMock = storage.NewMockProvider(mockAuthelia.Ctrl)
providers.StorageProvider = mockAuthelia.StorageProviderMock
mockAuthelia.NotifierMock = NewMockNotifier(mockAuthelia.Ctrl)

View File

@ -1,195 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: storage/provider.go
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
time "time"
models "github.com/clems4ever/authelia/models"
gomock "github.com/golang/mock/gomock"
)
// MockStorageProvider is a mock of Provider interface
type MockStorageProvider struct {
ctrl *gomock.Controller
recorder *MockStorageProviderMockRecorder
}
// MockStorageProviderMockRecorder is the mock recorder for MockStorageProvider
type MockStorageProviderMockRecorder struct {
mock *MockStorageProvider
}
// NewMockStorageProvider creates a new mock instance
func NewMockStorageProvider(ctrl *gomock.Controller) *MockStorageProvider {
mock := &MockStorageProvider{ctrl: ctrl}
mock.recorder = &MockStorageProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStorageProvider) EXPECT() *MockStorageProviderMockRecorder {
return m.recorder
}
// LoadPrefered2FAMethod mocks base method
func (m *MockStorageProvider) LoadPrefered2FAMethod(username string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadPrefered2FAMethod", username)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadPrefered2FAMethod indicates an expected call of LoadPrefered2FAMethod
func (mr *MockStorageProviderMockRecorder) LoadPrefered2FAMethod(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPrefered2FAMethod", reflect.TypeOf((*MockStorageProvider)(nil).LoadPrefered2FAMethod), username)
}
// SavePrefered2FAMethod mocks base method
func (m *MockStorageProvider) SavePrefered2FAMethod(username, method string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePrefered2FAMethod", username, method)
ret0, _ := ret[0].(error)
return ret0
}
// SavePrefered2FAMethod indicates an expected call of SavePrefered2FAMethod
func (mr *MockStorageProviderMockRecorder) SavePrefered2FAMethod(username, method interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePrefered2FAMethod", reflect.TypeOf((*MockStorageProvider)(nil).SavePrefered2FAMethod), username, method)
}
// FindIdentityVerificationToken mocks base method
func (m *MockStorageProvider) FindIdentityVerificationToken(token string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindIdentityVerificationToken", token)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken
func (mr *MockStorageProviderMockRecorder) FindIdentityVerificationToken(token interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockStorageProvider)(nil).FindIdentityVerificationToken), token)
}
// SaveIdentityVerificationToken mocks base method
func (m *MockStorageProvider) SaveIdentityVerificationToken(token string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", token)
ret0, _ := ret[0].(error)
return ret0
}
// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken
func (mr *MockStorageProviderMockRecorder) SaveIdentityVerificationToken(token interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockStorageProvider)(nil).SaveIdentityVerificationToken), token)
}
// RemoveIdentityVerificationToken mocks base method
func (m *MockStorageProvider) RemoveIdentityVerificationToken(token string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", token)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken
func (mr *MockStorageProviderMockRecorder) RemoveIdentityVerificationToken(token interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockStorageProvider)(nil).RemoveIdentityVerificationToken), token)
}
// SaveTOTPSecret mocks base method
func (m *MockStorageProvider) SaveTOTPSecret(username, secret string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveTOTPSecret", username, secret)
ret0, _ := ret[0].(error)
return ret0
}
// SaveTOTPSecret indicates an expected call of SaveTOTPSecret
func (mr *MockStorageProviderMockRecorder) SaveTOTPSecret(username, secret interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockStorageProvider)(nil).SaveTOTPSecret), username, secret)
}
// LoadTOTPSecret mocks base method
func (m *MockStorageProvider) LoadTOTPSecret(username string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadTOTPSecret", username)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadTOTPSecret indicates an expected call of LoadTOTPSecret
func (mr *MockStorageProviderMockRecorder) LoadTOTPSecret(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockStorageProvider)(nil).LoadTOTPSecret), username)
}
// SaveU2FDeviceHandle mocks base method
func (m *MockStorageProvider) SaveU2FDeviceHandle(username string, device []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", username, device)
ret0, _ := ret[0].(error)
return ret0
}
// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle
func (mr *MockStorageProviderMockRecorder) SaveU2FDeviceHandle(username, device interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockStorageProvider)(nil).SaveU2FDeviceHandle), username, device)
}
// LoadU2FDeviceHandle mocks base method
func (m *MockStorageProvider) LoadU2FDeviceHandle(username string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", username)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle
func (mr *MockStorageProviderMockRecorder) LoadU2FDeviceHandle(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockStorageProvider)(nil).LoadU2FDeviceHandle), username)
}
// AppendAuthenticationLog mocks base method
func (m *MockStorageProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AppendAuthenticationLog", attempt)
ret0, _ := ret[0].(error)
return ret0
}
// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog
func (mr *MockStorageProviderMockRecorder) AppendAuthenticationLog(attempt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorageProvider)(nil).AppendAuthenticationLog), attempt)
}
// LoadLatestAuthenticationLogs mocks base method
func (m *MockStorageProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", username, fromDate)
ret0, _ := ret[0].([]models.AuthenticationAttempt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs
func (mr *MockStorageProviderMockRecorder) LoadLatestAuthenticationLogs(username, fromDate interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockStorageProvider)(nil).LoadLatestAuthenticationLogs), username, fromDate)
}

View File

@ -2,4 +2,5 @@ package regulation
import "fmt"
// ErrUserIsBanned user is banned error message
var ErrUserIsBanned = fmt.Errorf("User is banned")

View File

@ -5,9 +5,9 @@ import (
"time"
"github.com/clems4ever/authelia/configuration/schema"
"github.com/clems4ever/authelia/mocks"
"github.com/clems4ever/authelia/models"
"github.com/clems4ever/authelia/regulation"
"github.com/clems4ever/authelia/storage"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
@ -17,14 +17,14 @@ type RegulatorSuite struct {
suite.Suite
ctrl *gomock.Controller
storageMock *mocks.MockStorageProvider
storageMock *storage.MockProvider
configuration schema.RegulationConfiguration
now time.Time
}
func (s *RegulatorSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.storageMock = mocks.NewMockStorageProvider(s.ctrl)
s.storageMock = storage.NewMockProvider(s.ctrl)
s.configuration = schema.RegulationConfiguration{
MaxRetries: 3,

View File

@ -13,6 +13,12 @@ type ProviderConfig struct {
providerConfig session.ProviderConfig
}
// U2FRegistration is a serializable version of a U2F registration
type U2FRegistration struct {
KeyHandle []byte
PublicKey []byte
}
// UserSession is the structure representing the session of a user.
type UserSession struct {
Username string
@ -29,7 +35,7 @@ type UserSession struct {
U2FChallenge *u2f.Challenge
// The registration representing a U2F device in DB set after identity verification.
// This is used in second phase of a U2F authentication.
U2FRegistration *u2f.Registration
U2FRegistration *U2FRegistration
// This boolean is set to true after identity verification and checked
// while doing the query actually updating the password.

View File

@ -1,6 +1,6 @@
package storage
var preferencesTableName = "PreferencesTableName"
var preferencesTableName = "Preferences"
var identityVerificationTokensTableName = "IdentityVerificationTokens"
var totpSecretsTableName = "TOTPSecrets"
var u2fDeviceHandlesTableName = "U2FDeviceHandles"

View File

@ -53,8 +53,8 @@ func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProv
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT deviceHandle FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, deviceHandle) VALUES (?, ?)", u2fDeviceHandlesTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),

View File

@ -61,8 +61,8 @@ func NewPostgreSQLProvider(configuration schema.PostgreSQLStorageConfiguration)
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=$1", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("INSERT INTO %s (username, secret) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET secret=$2", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT deviceHandle FROM %s WHERE username=$1", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, deviceHandle) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET deviceHandle=$2", u2fDeviceHandlesTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=$1", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, keyHandle, publicKey) VALUES ($1, $2, $3) ON CONFLICT (username) DO UPDATE SET keyHandle=$2, publicKey=$3", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES ($1, $2, $3)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>$1 AND username=$2 ORDER BY time DESC", authenticationLogsTableName),

View File

@ -19,8 +19,8 @@ type Provider interface {
SaveTOTPSecret(username string, secret string) error
LoadTOTPSecret(username string) (string, error)
SaveU2FDeviceHandle(username string, device []byte) error
LoadU2FDeviceHandle(username string) ([]byte, error)
SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error
LoadU2FDeviceHandle(username string) ([]byte, []byte, error)
AppendAuthenticationLog(attempt models.AuthenticationAttempt) error
LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error)

View File

@ -0,0 +1,195 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/clems4ever/authelia/storage (interfaces: Provider)
// Package storage is a generated GoMock package.
package storage
import (
models "github.com/clems4ever/authelia/models"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockProvider is a mock of Provider interface
type MockProvider struct {
ctrl *gomock.Controller
recorder *MockProviderMockRecorder
}
// MockProviderMockRecorder is the mock recorder for MockProvider
type MockProviderMockRecorder struct {
mock *MockProvider
}
// NewMockProvider creates a new mock instance
func NewMockProvider(ctrl *gomock.Controller) *MockProvider {
mock := &MockProvider{ctrl: ctrl}
mock.recorder = &MockProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockProvider) EXPECT() *MockProviderMockRecorder {
return m.recorder
}
// AppendAuthenticationLog mocks base method
func (m *MockProvider) AppendAuthenticationLog(arg0 models.AuthenticationAttempt) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AppendAuthenticationLog", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog
func (mr *MockProviderMockRecorder) AppendAuthenticationLog(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), arg0)
}
// FindIdentityVerificationToken mocks base method
func (m *MockProvider) FindIdentityVerificationToken(arg0 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindIdentityVerificationToken", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken
func (mr *MockProviderMockRecorder) FindIdentityVerificationToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerificationToken), arg0)
}
// LoadLatestAuthenticationLogs mocks base method
func (m *MockProvider) LoadLatestAuthenticationLogs(arg0 string, arg1 time.Time) ([]models.AuthenticationAttempt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", arg0, arg1)
ret0, _ := ret[0].([]models.AuthenticationAttempt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs
func (mr *MockProviderMockRecorder) LoadLatestAuthenticationLogs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadLatestAuthenticationLogs), arg0, arg1)
}
// LoadPrefered2FAMethod mocks base method
func (m *MockProvider) LoadPrefered2FAMethod(arg0 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadPrefered2FAMethod", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadPrefered2FAMethod indicates an expected call of LoadPrefered2FAMethod
func (mr *MockProviderMockRecorder) LoadPrefered2FAMethod(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPrefered2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPrefered2FAMethod), arg0)
}
// LoadTOTPSecret mocks base method
func (m *MockProvider) LoadTOTPSecret(arg0 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadTOTPSecret", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadTOTPSecret indicates an expected call of LoadTOTPSecret
func (mr *MockProviderMockRecorder) LoadTOTPSecret(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockProvider)(nil).LoadTOTPSecret), arg0)
}
// LoadU2FDeviceHandle mocks base method
func (m *MockProvider) LoadU2FDeviceHandle(arg0 string) ([]byte, []byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", arg0)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].([]byte)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle
func (mr *MockProviderMockRecorder) LoadU2FDeviceHandle(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).LoadU2FDeviceHandle), arg0)
}
// RemoveIdentityVerificationToken mocks base method
func (m *MockProvider) RemoveIdentityVerificationToken(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken
func (mr *MockProviderMockRecorder) RemoveIdentityVerificationToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerificationToken), arg0)
}
// SaveIdentityVerificationToken mocks base method
func (m *MockProvider) SaveIdentityVerificationToken(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken
func (mr *MockProviderMockRecorder) SaveIdentityVerificationToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerificationToken), arg0)
}
// SavePrefered2FAMethod mocks base method
func (m *MockProvider) SavePrefered2FAMethod(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePrefered2FAMethod", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SavePrefered2FAMethod indicates an expected call of SavePrefered2FAMethod
func (mr *MockProviderMockRecorder) SavePrefered2FAMethod(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePrefered2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePrefered2FAMethod), arg0, arg1)
}
// SaveTOTPSecret mocks base method
func (m *MockProvider) SaveTOTPSecret(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveTOTPSecret", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveTOTPSecret indicates an expected call of SaveTOTPSecret
func (mr *MockProviderMockRecorder) SaveTOTPSecret(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockProvider)(nil).SaveTOTPSecret), arg0, arg1)
}
// SaveU2FDeviceHandle mocks base method
func (m *MockProvider) SaveU2FDeviceHandle(arg0 string, arg1, arg2 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle
func (mr *MockProviderMockRecorder) SaveU2FDeviceHandle(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).SaveU2FDeviceHandle), arg0, arg1, arg2)
}

View File

@ -48,7 +48,8 @@ func (p *SQLProvider) initialize(db *sql.DB) error {
return err
}
_, err = db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, deviceHandle TEXT)", u2fDeviceHandlesTableName))
// keyHandle and publicKey are stored in base64 format
_, err = db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, keyHandle TEXT, publicKey TEXT)", u2fDeviceHandlesTableName))
if err != nil {
return err
}
@ -135,22 +136,37 @@ func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
}
// SaveU2FDeviceHandle save a registered U2F device registration blob.
func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte) error {
_, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle, username, base64.StdEncoding.EncodeToString(keyHandle))
func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error {
_, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle,
username,
base64.StdEncoding.EncodeToString(keyHandle),
base64.StdEncoding.EncodeToString(publicKey))
return err
}
// LoadU2FDeviceHandle load a U2F device registration blob for a given username.
func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, error) {
var deviceHandle string
if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&deviceHandle); err != nil {
func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) {
var keyHandleBase64, publicKeyBase64 string
if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil {
if err == sql.ErrNoRows {
return nil, ErrNoU2FDeviceHandle
return nil, nil, ErrNoU2FDeviceHandle
}
return nil, err
return nil, nil, err
}
return base64.StdEncoding.DecodeString(deviceHandle)
keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
if err != nil {
return nil, nil, err
}
publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
if err != nil {
return nil, nil, err
}
return keyHandle, publicKey, nil
}
// AppendAuthenticationLog append a mark to the authentication log.

View File

@ -32,8 +32,8 @@ func NewSQLiteProvider(path string) *SQLiteProvider {
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT deviceHandle FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, deviceHandle) VALUES (?, ?)", u2fDeviceHandlesTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),