From 3695aa8140eb91fd54a4cd849e1340ad4c36d987 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Tue, 23 Nov 2021 20:45:38 +1100 Subject: [PATCH] feat(storage): primary key for all tables and general qol refactoring (#2431) This is a massive overhaul to the SQL Storage for Authelia. It facilitates a whole heap of utility commands to help manage the database, primary keys, ensures all database requests use a context for cancellations, and paves the way for a few other PR's which improve the database. Fixes #1337 --- cmd/authelia/main.go | 7 +- docs/configuration/authentication/ldap.md | 1 + docs/configuration/storage/migrations.md | 24 + docs/contributing/style-guide.md | 33 ++ go.mod | 84 +++- go.sum | 17 +- internal/authentication/file_user_provider.go | 3 +- internal/authentication/ldap_user_provider.go | 10 +- .../ldap_user_provider_startup.go | 17 +- .../authentication/ldap_user_provider_test.go | 15 +- internal/authentication/user_provider.go | 5 +- internal/commands/const.go | 5 + internal/commands/helpers.go | 28 ++ internal/commands/root.go | 15 +- internal/commands/storage.go | 126 +++++ internal/commands/storage_run.go | 291 ++++++++++++ internal/configuration/koanf_callbacks.go | 24 + internal/configuration/provider.go | 19 +- internal/configuration/schema/validator.go | 101 ---- .../configuration/schema/validator_test.go | 37 -- internal/configuration/sources.go | 33 ++ internal/configuration/types.go | 8 + internal/handlers/const.go | 6 + internal/handlers/handler_firstfactor.go | 8 +- internal/handlers/handler_firstfactor_test.go | 16 +- internal/handlers/handler_register_totp.go | 18 +- .../handler_register_u2f_step1_test.go | 8 +- .../handlers/handler_register_u2f_step2.go | 8 +- internal/handlers/handler_sign_totp.go | 4 +- internal/handlers/handler_sign_totp_test.go | 39 +- internal/handlers/handler_sign_u2f_step1.go | 10 +- internal/handlers/handler_user_info.go | 82 +--- internal/handlers/handler_user_info_test.go | 170 ++++--- internal/handlers/totp.go | 49 +- internal/handlers/totp_mock.go | 26 +- internal/handlers/types.go | 15 - internal/middlewares/identity_verification.go | 9 +- .../middlewares/identity_verification_test.go | 28 +- internal/middlewares/types.go | 5 - internal/mocks/mock_authelia_ctx.go | 8 + internal/mocks/mock_notifier.go | 9 +- internal/mocks/mock_user_provider.go | 37 +- .../models/model_authentication_attempt.go | 17 + .../models/model_identity_verification.go | 12 + internal/models/model_migration.go | 14 + internal/models/model_totp_configuration.go | 11 + internal/models/model_u2f_device.go | 10 + internal/models/model_userinfo.go | 16 + internal/models/type_ipaddress.go | 42 ++ internal/models/type_startup_check.go | 6 + internal/models/types.go | 13 - internal/notification/file_notifier.go | 4 +- internal/notification/notifier.go | 5 +- internal/notification/smtp_notifier.go | 4 +- internal/ntp/ntp.go | 18 +- internal/ntp/ntp_test.go | 3 +- internal/ntp/types.go | 3 + internal/oidc/provider.go | 5 +- internal/oidc/store.go | 4 +- internal/oidc/store_test.go | 22 +- internal/regulation/regulator.go | 22 +- internal/regulation/regulator_test.go | 37 +- internal/regulation/types.go | 2 +- internal/storage/const.go | 73 +-- internal/storage/errors.go | 35 +- internal/storage/migrations.go | 204 ++++++++ .../V0001.Initial_Schema.all.down.sql | 6 + .../V0001.Initial_Schema.mysql.up.sql | 55 +++ .../V0001.Initial_Schema.postgres.up.sql | 55 +++ .../V0001.Initial_Schema.sqlite.up.sql | 54 +++ internal/storage/migrations_test.go | 154 ++++++ internal/storage/mysql_provider.go | 85 ---- internal/storage/postgres_provider.go | 90 ---- internal/storage/provider.go | 45 +- internal/storage/provider_mock.go | 440 +++++++++++------ internal/storage/sql_provider.go | 391 +++++++-------- .../storage/sql_provider_backend_mysql.go | 53 +++ .../storage/sql_provider_backend_postgres.go | 72 +++ .../storage/sql_provider_backend_sqlite.go | 22 + internal/storage/sql_provider_queries.go | 125 +++++ .../storage/sql_provider_queries_special.go | 109 +++++ internal/storage/sql_provider_schema.go | 327 +++++++++++++ internal/storage/sql_provider_schema_pre1.go | 449 ++++++++++++++++++ internal/storage/sql_provider_schema_test.go | 134 ++++++ internal/storage/sql_provider_test.go | 400 ---------------- internal/storage/sqlite_provider.go | 58 --- internal/storage/sqlmock_provider.go | 60 --- internal/storage/types.go | 36 +- internal/storage/upgrades.go | 76 --- internal/suites/suite_standalone_test.go | 4 +- 90 files changed, 3602 insertions(+), 1738 deletions(-) create mode 100644 docs/configuration/storage/migrations.md create mode 100644 internal/commands/helpers.go create mode 100644 internal/commands/storage.go create mode 100644 internal/commands/storage_run.go create mode 100644 internal/models/model_authentication_attempt.go create mode 100644 internal/models/model_identity_verification.go create mode 100644 internal/models/model_migration.go create mode 100644 internal/models/model_totp_configuration.go create mode 100644 internal/models/model_u2f_device.go create mode 100644 internal/models/model_userinfo.go create mode 100644 internal/models/type_ipaddress.go create mode 100644 internal/models/type_startup_check.go delete mode 100644 internal/models/types.go create mode 100644 internal/storage/migrations.go create mode 100644 internal/storage/migrations/V0001.Initial_Schema.all.down.sql create mode 100644 internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql create mode 100644 internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql create mode 100644 internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql create mode 100644 internal/storage/migrations_test.go delete mode 100644 internal/storage/mysql_provider.go delete mode 100644 internal/storage/postgres_provider.go create mode 100644 internal/storage/sql_provider_backend_mysql.go create mode 100644 internal/storage/sql_provider_backend_postgres.go create mode 100644 internal/storage/sql_provider_backend_sqlite.go create mode 100644 internal/storage/sql_provider_queries.go create mode 100644 internal/storage/sql_provider_queries_special.go create mode 100644 internal/storage/sql_provider_schema.go create mode 100644 internal/storage/sql_provider_schema_pre1.go create mode 100644 internal/storage/sql_provider_schema_test.go delete mode 100644 internal/storage/sql_provider_test.go delete mode 100644 internal/storage/sqlite_provider.go delete mode 100644 internal/storage/sqlmock_provider.go delete mode 100644 internal/storage/upgrades.go diff --git a/cmd/authelia/main.go b/cmd/authelia/main.go index 92555dfb8..c0fc695ee 100644 --- a/cmd/authelia/main.go +++ b/cmd/authelia/main.go @@ -1,14 +1,13 @@ package main import ( + "os" + "github.com/authelia/authelia/v4/internal/commands" - "github.com/authelia/authelia/v4/internal/logging" ) func main() { - logger := logging.Logger() - if err := commands.NewRootCmd().Execute(); err != nil { - logger.Fatal(err) + os.Exit(1) } } diff --git a/docs/configuration/authentication/ldap.md b/docs/configuration/authentication/ldap.md index 52bb14af5..53c2e975d 100644 --- a/docs/configuration/authentication/ldap.md +++ b/docs/configuration/authentication/ldap.md @@ -264,4 +264,5 @@ In versions <= `4.24.0` not including the `username_attribute` placeholder will and will result in session resets when the refresh interval has expired, default of 5 minutes. [LDAP GeneralizedTime]: https://ldapwiki.com/wiki/GeneralizedTime +[username attribute]: #username_attribute [TechNet wiki]: https://social.technet.microsoft.com/wiki/contents/articles/5392.active-directory-ldap-syntax-filters.aspx diff --git a/docs/configuration/storage/migrations.md b/docs/configuration/storage/migrations.md new file mode 100644 index 000000000..00d1ed97c --- /dev/null +++ b/docs/configuration/storage/migrations.md @@ -0,0 +1,24 @@ +--- +layout: default +title: Migrations +parent: Storage Backends +grand_parent: Configuration +nav_order: 5 +--- + +Storage migrations are important for keeping your database compatible with Authelia. Authelia will automatically upgrade +your schema on startup. However, if you wish to use an older version of Authelia you may be required to manually +downgrade your schema with a version of Authelia that supports your current schema. + +## Schema Version to Authelia Version map + +This table contains a list of schema versions and the corresponding release of Authelia that shipped with that version. +This means all Authelia versions between two schema versions use the first schema version. + +For example for version pre1, it is used for all versions between it and the version 1 schema, so 4.0.0 to 4.32.2. In +this instance if you wanted to downgrade to pre1 you would need to use an Authelia binary with version 4.33.0 or higher. + +|Schema Version|Authelia Version|Notes | +|:------------:|:--------------:|:----------------------------------------------------------:| +|pre1 |4.0.0 |Downgrading to this version requires you use the --pre1 flag| +|1 |4.33.0 | | diff --git a/docs/contributing/style-guide.md b/docs/contributing/style-guide.md index 5d2af883c..1aabc517d 100644 --- a/docs/contributing/style-guide.md +++ b/docs/contributing/style-guide.md @@ -99,3 +99,36 @@ This section has the required status of the value and must be one of `yes`, `no` depends on other configuration options. If it's situational the situational usage should be documented. This is immediately followed by the styles `.label`, `.label-config`, and a traffic lights color label, i.e. if yes `.label-red`, if no `.label-green`, or if situational `.label-yellow`. + +### Storage +This section outlines some rules for storage contributions. Including but not limited to migrations, schema rules, etc. + +#### Migrations +All migrations must have an up and down migration, preferably idempotent. + +All migrations must be named in the following format: +```text +V....sql +``` + +##### version +A 4 digit version number, should be in sequential order. + +##### name +A name containing alphanumeric characters, underscores (treated as spaces), hyphens, and no spaces. + +##### engine +The target engine for the migration, options are all, mysql, postgres, and sqlite. + +#### Primary Key +All tables must have a primary key. This primary key must be an integer with auto increment enabled, or in the case of +PostgreSQL a serial type. + +#### Table/Column Names +Table and Column names must be in snake case format. This means they must have only lowercase letters, and have words +seperated by underscores. The reasoning for this is that some database engines ignore case by default and this makes it +easy to be consistent with the casing. + +#### Context +All database methods should include the context attribute so that database requests that are no longer needed are +terminated appropriately. diff --git a/go.mod b/go.mod index fd75e3a5c..89c135031 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,12 @@ module github.com/authelia/authelia/v4 -go 1.16 +go 1.17 require ( - github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4 - github.com/Workiva/go-datastructures v1.0.53 - github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef + github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/deckarep/golang-set v1.7.1 github.com/duosecurity/duo_api_golang v0.0.0-20211027140842-72da735c6f15 - github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect github.com/fasthttp/router v1.4.4 github.com/fasthttp/session/v2 v2.4.4 github.com/go-ldap/ldap/v3 v3.4.1 @@ -19,6 +16,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.14.0 + github.com/jmoiron/sqlx v1.3.1 github.com/knadh/koanf v1.3.2 github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/mitchellh/mapstructure v1.4.2 @@ -30,16 +28,88 @@ require ( github.com/simia-tech/crypt v0.5.0 github.com/sirupsen/logrus v1.8.1 github.com/spf13/cobra v1.2.1 + github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.0 github.com/tstranex/u2f v1.0.0 github.com/valyala/fasthttp v1.31.0 - golang.org/x/sys v0.0.0-20210902050250-f475640dd07b // indirect golang.org/x/text v0.3.7 gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/yaml.v2 v2.4.0 ) +require ( + github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect + github.com/andybalholm/brotli v1.0.2 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgraph-io/ristretto v0.1.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.0 // indirect + github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect + github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect + github.com/go-redis/redis/v8 v8.11.4 // indirect + github.com/gobuffalo/pop/v5 v5.3.3 // indirect + github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgconn v1.10.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.2.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect + github.com/jackc/pgtype v1.9.0 // indirect + github.com/jandelgado/gcov2lcov v1.0.4 // indirect + github.com/klauspost/compress v1.13.4 // indirect + github.com/magiconair/properties v1.8.5 // indirect + github.com/mattn/goveralls v0.0.6 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/ory/go-acc v0.2.6 // indirect + github.com/ory/go-convenience v0.1.0 // indirect + github.com/ory/viper v1.7.5 // indirect + github.com/ory/x v0.0.288 // indirect + github.com/pborman/uuid v1.2.1 // indirect + github.com/pelletier/go-toml v1.9.3 // indirect + github.com/philhofer/fwd v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/savsgio/dictpool v0.0.0-20210921080634-84324d0689d7 // indirect + github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4 // indirect + github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 // indirect + github.com/spf13/afero v1.6.0 // indirect + github.com/spf13/cast v1.3.2-0.20200723214538-8d17101741c8 // indirect + github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/sqs/goreturns v0.0.0-20181028201513-538ac6014518 // indirect + github.com/subosito/gotenv v1.2.0 // indirect + github.com/tinylib/msgp v1.1.6 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/ysmood/goob v0.3.0 // indirect + github.com/ysmood/gson v0.6.4 // indirect + github.com/ysmood/leakless v0.7.0 // indirect + go.opentelemetry.io/contrib v0.20.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.20.0 // indirect + go.opentelemetry.io/otel v0.20.0 // indirect + go.opentelemetry.io/otel/metric v0.20.0 // indirect + go.opentelemetry.io/otel/trace v0.20.0 // indirect + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect + golang.org/x/mod v0.4.2 // indirect + golang.org/x/net v0.0.0-20210510120150-4163338589ed // indirect + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect + golang.org/x/tools v0.1.2 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c // indirect + google.golang.org/grpc v1.38.0 // indirect + google.golang.org/protobuf v1.26.0 // indirect + gopkg.in/ini.v1 v1.62.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect +) + replace ( - github.com/mattn/go-sqlite3 v2.0.3+incompatible => github.com/mattn/go-sqlite3 v1.14.8 + github.com/mattn/go-sqlite3 v2.0.3+incompatible => github.com/mattn/go-sqlite3 v1.14.9 github.com/tidwall/gjson => github.com/tidwall/gjson v1.11.0 ) diff --git a/go.sum b/go.sum index dbc861e7f..e7d12382e 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,6 @@ github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzS github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/DataDog/datadog-go v4.0.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4 h1:vdT7QwBhJJEVNFMBNhRSFDRCB6O16T28VhvqRgqFyn8= github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4/go.mod h1:SvXOG8ElV28oAiG9zv91SDe5+9PfIr7PPccpr8YyXNs= @@ -66,8 +64,6 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= -github.com/Workiva/go-datastructures v1.0.53 h1:J6Y/52yX10Xc5JjXmGtWoSSxs3mZnGSaq37xZZh7Yig= -github.com/Workiva/go-datastructures v1.0.53/go.mod h1:1yZL+zfsztete+ePzZz/Zb1/t5BnDuE2Ya2MMGhzP6A= github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= github.com/ajg/form v0.0.0-20160822230020-523a5da1a92f/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= @@ -91,8 +87,8 @@ github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:l github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= -github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg= -github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= github.com/aws/aws-sdk-go v1.23.19/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= @@ -825,6 +821,7 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jmoiron/sqlx v0.0.0-20180614180643-0dae4fefe7c0/go.mod h1:IiEW3SEiiErVyFdH8NTuWjSifiEQKUoyK3LNqr2kCHU= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/jmoiron/sqlx v1.3.1 h1:aLN7YINNZ7cYOPK3QC83dbM6KT0NMqVMw961TqrejlE= github.com/jmoiron/sqlx v1.3.1/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901/go.mod h1:Z86h9688Y0wesXCyonoVr47MasHilkuLMqGhRZ4Hpak= github.com/joho/godotenv v1.2.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= @@ -948,8 +945,9 @@ github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOq github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= +github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mattn/goveralls v0.0.6 h1:cr8Y0VMo/MnEZBjxNN/vh6G90SZ7IMb6lms1dzMoO+Y= github.com/mattn/goveralls v0.0.6/go.mod h1:h8b4ow6FxSPMQHF6o2ve3qsclnffZjYTNEKmLesRwqw= @@ -1304,14 +1302,12 @@ github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7V github.com/tidwall/sjson v1.1.5 h1:wsUceI/XDyZk3J1FUvuuYlK62zJv2HO2Pzb8A5EWdUE= github.com/tidwall/sjson v1.1.5/go.mod h1:VuJzsZnTowhSxWdOgsAnb886i4AjEyTkk7tNtsL7EYE= github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= -github.com/tinylib/msgp v1.1.5/go.mod h1:eQsjooMTnV42mHu917E26IogZ2930nFyBQdofk10Udg= github.com/tinylib/msgp v1.1.6 h1:i+SbKraHhnrf9M5MYmvQhFnbLhAXSDWF8WWsuyRdocw= github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tstranex/u2f v1.0.0 h1:HhJkSzDDlVSVIVt7pDJwCHQj67k7A5EeBgPmeD+pVsQ= github.com/tstranex/u2f v1.0.0/go.mod h1:eahSLaqAS0zsIEv80+vXT7WanXs7MQQDg3j3wGBSayo= -github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q= github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g= github.com/uber/jaeger-client-go v2.15.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-client-go v2.22.1+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= @@ -1681,9 +1677,8 @@ golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210902050250-f475640dd07b h1:S7hKs0Flbq0bbc9xgYt4stIEG1zNDFqyrPwAX2Wj/sE= -golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/authentication/file_user_provider.go b/internal/authentication/file_user_provider.go index 9947c4f75..548c19c7a 100644 --- a/internal/authentication/file_user_provider.go +++ b/internal/authentication/file_user_provider.go @@ -9,7 +9,6 @@ import ( "sync" "github.com/asaskevich/govalidator" - "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" "github.com/authelia/authelia/v4/internal/configuration/schema" @@ -208,6 +207,6 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e } // StartupCheck implements the startup check provider interface. -func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) { +func (p *FileUserProvider) StartupCheck() (err error) { return nil } diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go index 5e7bd645c..01ebe9678 100644 --- a/internal/authentication/ldap_user_provider.go +++ b/internal/authentication/ldap_user_provider.go @@ -21,7 +21,7 @@ type LDAPUserProvider struct { configuration schema.LDAPAuthenticationBackendConfiguration tlsConfig *tls.Config dialOpts []ldap.DialOpt - logger *logrus.Logger + log *logrus.Logger connectionFactory LDAPConnectionFactory disableResetPassword bool @@ -72,7 +72,7 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura configuration: configuration, tlsConfig: tlsConfig, dialOpts: dialOpts, - logger: logging.Logger(), + log: logging.Logger(), connectionFactory: factory, disableResetPassword: disableResetPassword, } @@ -148,7 +148,7 @@ func (p *LDAPUserProvider) resolveUsersFilter(inputUsername string) (filter stri filter = strings.ReplaceAll(filter, ldapPlaceholderInput, p.ldapEscape(inputUsername)) } - p.logger.Tracef("Computed user filter is %s", filter) + p.log.Tracef("Computed user filter is %s", filter) return filter } @@ -223,7 +223,7 @@ func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ld } } - p.logger.Tracef("Computed groups filter is %s", filter) + p.log.Tracef("Computed groups filter is %s", filter) return filter, nil } @@ -262,7 +262,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error for _, res := range sr.Entries { if len(res.Attributes) == 0 { - p.logger.Warningf("No groups retrieved from LDAP for user %s", inputUsername) + p.log.Warningf("No groups retrieved from LDAP for user %s", inputUsername) break } diff --git a/internal/authentication/ldap_user_provider_startup.go b/internal/authentication/ldap_user_provider_startup.go index 5c56df22c..5a2806a31 100644 --- a/internal/authentication/ldap_user_provider_startup.go +++ b/internal/authentication/ldap_user_provider_startup.go @@ -4,13 +4,12 @@ import ( "strings" "github.com/go-ldap/ldap/v3" - "github.com/sirupsen/logrus" "github.com/authelia/authelia/v4/internal/configuration/schema" ) // StartupCheck implements the startup check provider interface. -func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) { +func (p *LDAPUserProvider) StartupCheck() (err error) { conn, err := p.connect(p.configuration.User, p.configuration.Password) if err != nil { return err @@ -33,7 +32,7 @@ func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) { // Iterate the attribute values to see what the server supports. for _, attr := range sr.Entries[0].Attributes { if attr.Name == ldapSupportedExtensionAttribute { - logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", ")) + p.log.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", ")) for _, oid := range attr.Values { if oid == ldapOIDPasswdModifyExtension { @@ -48,7 +47,7 @@ func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) { if !p.supportExtensionPasswdModify && !p.disableResetPassword && p.configuration.Implementation != schema.LDAPImplementationActiveDirectory { - logger.Warn("Your LDAP server implementation may not support a method for password hashing " + + p.log.Warn("Your LDAP server implementation may not support a method for password hashing " + "known to Authelia, it's strongly recommended you ensure your directory server hashes the password " + "attribute when users reset their password via Authelia.") } @@ -61,7 +60,7 @@ func (p *LDAPUserProvider) parseDynamicUsersConfiguration() { p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{mail_attribute}", p.configuration.MailAttribute) p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{display_name_attribute}", p.configuration.DisplayNameAttribute) - p.logger.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter) + p.log.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter) p.usersAttributes = []string{ p.configuration.DisplayNameAttribute, @@ -75,13 +74,13 @@ func (p *LDAPUserProvider) parseDynamicUsersConfiguration() { p.usersBaseDN = p.configuration.BaseDN } - p.logger.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN) + p.log.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN) if strings.Contains(p.configuration.UsersFilter, ldapPlaceholderInput) { p.usersFilterReplacementInput = true } - p.logger.Tracef("Detected user filter replacements that need to be resolved per lookup are: %s=%v", + p.log.Tracef("Detected user filter replacements that need to be resolved per lookup are: %s=%v", ldapPlaceholderInput, p.usersFilterReplacementInput) } @@ -96,7 +95,7 @@ func (p *LDAPUserProvider) parseDynamicGroupsConfiguration() { p.groupsBaseDN = p.configuration.BaseDN } - p.logger.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN) + p.log.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN) if strings.Contains(p.configuration.GroupsFilter, ldapPlaceholderInput) { p.groupsFilterReplacementInput = true @@ -110,5 +109,5 @@ func (p *LDAPUserProvider) parseDynamicGroupsConfiguration() { p.groupsFilterReplacementDN = true } - p.logger.Tracef("Detected group filter replacements that need to be resolved per lookup are: input=%v, username=%v, dn=%v", p.groupsFilterReplacementInput, p.groupsFilterReplacementUsername, p.groupsFilterReplacementDN) + p.log.Tracef("Detected group filter replacements that need to be resolved per lookup are: input=%v, username=%v, dn=%v", p.groupsFilterReplacementInput, p.groupsFilterReplacementUsername, p.groupsFilterReplacementDN) } diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go index b9b739f38..202cd636e 100644 --- a/internal/authentication/ldap_user_provider_test.go +++ b/internal/authentication/ldap_user_provider_test.go @@ -12,7 +12,6 @@ import ( "golang.org/x/text/encoding/unicode" "github.com/authelia/authelia/v4/internal/configuration/schema" - "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/utils" ) @@ -216,7 +215,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) { gomock.InOrder(dialURL, connBind, searchOIDs, connClose) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() assert.NoError(t, err) assert.True(t, ldapClient.supportExtensionPasswdModify) @@ -273,7 +272,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) { gomock.InOrder(dialURL, connBind, searchOIDs, connClose) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() assert.NoError(t, err) assert.False(t, ldapClient.supportExtensionPasswdModify) @@ -306,7 +305,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) { DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()). Return(mockConn, errors.New("could not connect")) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() assert.EqualError(t, err, "could not connect") assert.False(t, ldapClient.supportExtensionPasswdModify) @@ -351,7 +350,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) { gomock.InOrder(dialURL, connBind, searchOIDs, connClose) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() assert.EqualError(t, err, "could not perform the search") assert.False(t, ldapClient.supportExtensionPasswdModify) @@ -755,7 +754,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) { gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() require.NoError(t, err) err = ldapClient.UpdatePassword("john", "password") @@ -862,7 +861,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) { gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() require.NoError(t, err) err = ldapClient.UpdatePassword("john", "password") @@ -966,7 +965,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) { gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose) - err := ldapClient.StartupCheck(logging.Logger()) + err := ldapClient.StartupCheck() require.NoError(t, err) err = ldapClient.UpdatePassword("john", "password") diff --git a/internal/authentication/user_provider.go b/internal/authentication/user_provider.go index d70675cef..834c959d2 100644 --- a/internal/authentication/user_provider.go +++ b/internal/authentication/user_provider.go @@ -1,14 +1,15 @@ package authentication import ( - "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/models" ) // UserProvider is the interface for checking user password and // gathering user details. type UserProvider interface { + models.StartupCheck + CheckUserPassword(username string, password string) (valid bool, err error) GetDetails(username string) (details *UserDetails, err error) UpdatePassword(username string, newPassword string) (err error) - StartupCheck(logger *logrus.Logger) (err error) } diff --git a/internal/commands/const.go b/internal/commands/const.go index a1373ae12..c8cd5d1b5 100644 --- a/internal/commands/const.go +++ b/internal/commands/const.go @@ -75,3 +75,8 @@ PowerShell: PS> authelia completion powershell > authelia.ps1 # and source this file from your PowerShell profile. ` + +const ( + storageMigrateDirectionUp = "up" + storageMigrateDirectionDown = "down" +) diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go new file mode 100644 index 000000000..e90c78464 --- /dev/null +++ b/internal/commands/helpers.go @@ -0,0 +1,28 @@ +package commands + +import ( + "errors" + + "github.com/authelia/authelia/v4/internal/storage" +) + +func getStorageProvider() (provider storage.Provider, err error) { + switch { + case config.Storage.PostgreSQL != nil: + provider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL) + case config.Storage.MySQL != nil: + provider = storage.NewMySQLProvider(*config.Storage.MySQL) + case config.Storage.Local != nil: + provider = storage.NewSQLiteProvider(config.Storage.Local.Path) + default: + return nil, errors.New("no storage provider configured") + } + + if (config.Storage.MySQL != nil && config.Storage.PostgreSQL != nil) || + (config.Storage.MySQL != nil && config.Storage.Local != nil) || + (config.Storage.PostgreSQL != nil && config.Storage.Local != nil) { + return nil, errors.New("multiple storage providers are configured but should only configure one") + } + + return provider, err +} diff --git a/internal/commands/root.go b/internal/commands/root.go index 7d8f4be84..b3c96641a 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -13,6 +13,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/notification" "github.com/authelia/authelia/v4/internal/ntp" "github.com/authelia/authelia/v4/internal/oidc" @@ -46,6 +47,7 @@ func NewRootCmd() (cmd *cobra.Command) { newCompletionCmd(), NewHashPasswordCmd(), NewRSACmd(), + NewStorageCmd(), newValidateConfigCmd(), ) @@ -101,9 +103,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers storageProvider = storage.NewMySQLProvider(*config.Storage.MySQL) case config.Storage.Local != nil: storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path) - default: - // TODO: Add storage provider startup check and remove this. - errors = append(errors, fmt.Errorf("unrecognized storage provider")) } var ( @@ -162,6 +161,12 @@ func doStartupChecks(config *schema.Configuration, providers *middlewares.Provid err error ) + if err = doStartupCheck(logger, "storage", providers.StorageProvider, false); err != nil { + logger.Errorf("Failure running the storage provider startup check: %+v", err) + + failures = append(failures, "storage") + } + if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil { logger.Errorf("Failure running the user provider startup check: %+v", err) @@ -187,7 +192,7 @@ func doStartupChecks(config *schema.Configuration, providers *middlewares.Provid } } -func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) { +func doStartupCheck(logger *logrus.Logger, name string, provider models.StartupCheck, disabled bool) (err error) { if disabled { logger.Debugf("%s provider: startup check skipped as it is disabled", name) return nil @@ -197,7 +202,7 @@ func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.Pro return fmt.Errorf("unrecognized provider or it is not configured properly") } - if err = provider.StartupCheck(logger); err != nil { + if err = provider.StartupCheck(); err != nil { return err } diff --git a/internal/commands/storage.go b/internal/commands/storage.go new file mode 100644 index 000000000..f6b62e5a8 --- /dev/null +++ b/internal/commands/storage.go @@ -0,0 +1,126 @@ +package commands + +import ( + "github.com/spf13/cobra" +) + +// NewStorageCmd returns a new storage *cobra.Command. +func NewStorageCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "storage", + Short: "Manage the Authelia storage", + Args: cobra.NoArgs, + PersistentPreRunE: storagePersistentPreRunE, + } + + cmd.PersistentFlags().StringSliceP("config", "c", []string{"config.yml"}, "configuration file to load for the storage migration") + + cmd.PersistentFlags().String("sqlite.path", "", "the SQLite database path") + + cmd.PersistentFlags().String("mysql.host", "", "the MySQL hostname") + cmd.PersistentFlags().Int("mysql.port", 3306, "the MySQL port") + cmd.PersistentFlags().String("mysql.database", "authelia", "the MySQL database name") + cmd.PersistentFlags().String("mysql.username", "authelia", "the MySQL username") + cmd.PersistentFlags().String("mysql.password", "", "the MySQL password") + + cmd.PersistentFlags().String("postgres.host", "", "the PostgreSQL hostname") + cmd.PersistentFlags().Int("postgres.port", 5432, "the PostgreSQL port") + cmd.PersistentFlags().String("postgres.database", "authelia", "the PostgreSQL database name") + cmd.PersistentFlags().String("postgres.username", "authelia", "the PostgreSQL username") + cmd.PersistentFlags().String("postgres.password", "", "the PostgreSQL password") + + cmd.AddCommand( + newStorageMigrateCmd(), + newStorageSchemaInfoCmd(), + ) + + return cmd +} + +func newStorageSchemaInfoCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "schema-info", + Short: "Show the storage information", + RunE: storageSchemaInfoRunE, + } + + return cmd +} + +// NewMigrationCmd returns a new Migration Cmd. +func newStorageMigrateCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "migrate", + Short: "Perform or list migrations", + Args: cobra.NoArgs, + } + + cmd.AddCommand( + newStorageMigrateUpCmd(), newStorageMigrateDownCmd(), + newStorageMigrateListUpCmd(), newStorageMigrateListDownCmd(), + newStorageMigrateHistoryCmd(), + ) + + return cmd +} + +func newStorageMigrateHistoryCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "history", + Short: "Show migration history", + Args: cobra.NoArgs, + RunE: storageMigrateHistoryRunE, + } + + return cmd +} + +func newStorageMigrateListUpCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "list-up", + Short: "List the up migrations available", + Args: cobra.NoArgs, + RunE: newStorageMigrateListRunE(true), + } + + return cmd +} + +func newStorageMigrateListDownCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "list-down", + Short: "List the down migrations available", + Args: cobra.NoArgs, + RunE: newStorageMigrateListRunE(false), + } + + return cmd +} + +func newStorageMigrateUpCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: storageMigrateDirectionUp, + Short: "Perform a migration up", + Args: cobra.NoArgs, + RunE: newStorageMigrationRunE(true), + } + + cmd.Flags().IntP("target", "t", 0, "sets the version to migrate to, by default this is the latest version") + + return cmd +} + +func newStorageMigrateDownCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: storageMigrateDirectionDown, + Short: "Perform a migration down", + Args: cobra.NoArgs, + RunE: newStorageMigrationRunE(false), + } + + cmd.Flags().IntP("target", "t", 0, "sets the version to migrate to") + cmd.Flags().Bool("pre1", false, "sets pre1 as the version to migrate to") + cmd.Flags().Bool("destroy-data", false, "confirms you want to destroy data with this migration") + + return cmd +} diff --git a/internal/commands/storage_run.go b/internal/commands/storage_run.go new file mode 100644 index 000000000..681683401 --- /dev/null +++ b/internal/commands/storage_run.go @@ -0,0 +1,291 @@ +package commands + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + + "github.com/spf13/cobra" + + "github.com/authelia/authelia/v4/internal/configuration" + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/configuration/validator" + "github.com/authelia/authelia/v4/internal/storage" +) + +func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) { + configs, err := cmd.Flags().GetStringSlice("config") + if err != nil { + return err + } + + sources := make([]configuration.Source, 0, len(configs)+3) + + if cmd.Flags().Changed("config") { + for _, configFile := range configs { + if _, err := os.Stat(configFile); os.IsNotExist(err) { + return fmt.Errorf("could not load the provided configuration file %s: %w", configFile, err) + } + + sources = append(sources, configuration.NewYAMLFileSource(configFile)) + } + } else { + if _, err := os.Stat(configs[0]); err == nil { + sources = append(sources, configuration.NewYAMLFileSource(configs[0])) + } + } + + mapping := map[string]string{ + "sqlite.path": "storage.local.path", + "mysql.host": "storage.mysql.host", + "mysql.port": "storage.mysql.port", + "mysql.database": "storage.mysql.database", + "mysql.username": "storage.mysql.username", + "mysql.password": "storage.mysql.password", + "postgres.host": "storage.postgres.host", + "postgres.port": "storage.postgres.port", + "postgres.database": "storage.postgres.database", + "postgres.username": "storage.postgres.username", + "postgres.password": "storage.postgres.password", + "postgres.schema": "storage.postgres.schema", + } + + sources = append(sources, configuration.NewEnvironmentSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter)) + sources = append(sources, configuration.NewSecretsSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter)) + sources = append(sources, configuration.NewCommandLineSourceWithMapping(cmd.Flags(), mapping, true, false)) + + val := schema.NewStructValidator() + + config = &schema.Configuration{} + + _, err = configuration.LoadAdvanced(val, "storage", &config.Storage, sources...) + if err != nil { + return err + } + + if val.HasErrors() { + var finalErr error + + for i, err := range val.Errors() { + if i == 0 { + finalErr = err + continue + } + + finalErr = fmt.Errorf("%w, %v", finalErr, err) + } + + return finalErr + } + + validator.ValidateStorage(config.Storage, val) + + if val.HasErrors() { + var finalErr error + + for i, err := range val.Errors() { + if i == 0 { + finalErr = err + continue + } + + finalErr = fmt.Errorf("%w, %v", finalErr, err) + } + + return finalErr + } + + return nil +} + +func storageMigrateHistoryRunE(_ *cobra.Command, _ []string) (err error) { + var ( + provider storage.Provider + ctx = context.Background() + ) + + provider, err = getStorageProvider() + if err != nil { + return err + } + + migrations, err := provider.SchemaMigrationHistory(ctx) + if err != nil { + return err + } + + if len(migrations) == 0 { + return errors.New("no migration history found which may indicate a broken schema") + } + + fmt.Printf("Migration History:\n\nID\tDate\t\t\t\tBefore\tAfter\tAuthelia Version\n") + + for _, m := range migrations { + fmt.Printf("%d\t%s\t%d\t%d\t%s\n", m.ID, m.Applied.Format("2006-01-02 15:04:05 -0700"), m.Before, m.After, m.Version) + } + + return nil +} + +func newStorageMigrateListRunE(up bool) func(cmd *cobra.Command, args []string) (err error) { + return func(cmd *cobra.Command, args []string) (err error) { + var ( + provider storage.Provider + ctx = context.Background() + migrations []storage.SchemaMigration + directionStr string + ) + + provider, err = getStorageProvider() + if err != nil { + return err + } + + if up { + migrations, err = provider.SchemaMigrationsUp(ctx, 0) + directionStr = "Up" + } else { + migrations, err = provider.SchemaMigrationsDown(ctx, 0) + directionStr = "Down" + } + + if err != nil { + if err.Error() == "cannot migrate to the same version as prior" { + fmt.Printf("No %s migrations found\n", directionStr) + + return nil + } + + return err + } + + if len(migrations) == 0 { + fmt.Printf("Storage Schema Migration List (%s)\n\nNo Migrations Available\n", directionStr) + } else { + fmt.Printf("Storage Schema Migration List (%s)\n\nVersion\t\tDescription\n", directionStr) + + for _, migration := range migrations { + fmt.Printf("%d\t\t%s\n", migration.Version, migration.Name) + } + } + + return nil + } +} + +func newStorageMigrationRunE(up bool) func(cmd *cobra.Command, args []string) (err error) { + return func(cmd *cobra.Command, args []string) (err error) { + var ( + provider storage.Provider + ctx = context.Background() + ) + + provider, err = getStorageProvider() + if err != nil { + return err + } + + target, err := cmd.Flags().GetInt("target") + if err != nil { + return err + } + + switch { + case up: + switch cmd.Flags().Changed("target") { + case true: + return provider.SchemaMigrate(ctx, true, target) + default: + return provider.SchemaMigrate(ctx, true, storage.SchemaLatest) + } + default: + if !cmd.Flags().Changed("target") { + return errors.New("must set target") + } + + if err = storageMigrateDownConfirmDestroy(cmd); err != nil { + return err + } + + pre1, err := cmd.Flags().GetBool("pre1") + if err != nil { + return err + } + + switch { + case pre1: + return provider.SchemaMigrate(ctx, false, -1) + default: + return provider.SchemaMigrate(ctx, false, target) + } + } + } +} + +func storageMigrateDownConfirmDestroy(cmd *cobra.Command) (err error) { + destroy, err := cmd.Flags().GetBool("destroy-data") + if err != nil { + return err + } + + if !destroy { + fmt.Printf("Schema Down Migrations may DESTROY data, type 'DESTROY' and press return to continue: ") + + var text string + + _, _ = fmt.Scanln(&text) + + if text != "DESTROY" { + return errors.New("cancelling down migration due to user not accepting data destruction") + } + } + + return nil +} + +func storageSchemaInfoRunE(_ *cobra.Command, _ []string) (err error) { + var ( + provider storage.Provider + ctx = context.Background() + upgradeStr string + tablesStr string + ) + + provider, err = getStorageProvider() + if err != nil { + return err + } + + version, err := provider.SchemaVersion(ctx) + if err != nil && err.Error() != "unknown schema state" { + return err + } + + tables, err := provider.SchemaTables(ctx) + if err != nil { + return err + } + + if len(tables) == 0 { + tablesStr = "N/A" + } else { + tablesStr = strings.Join(tables, ", ") + } + + latest, err := provider.SchemaLatestVersion() + if err != nil { + return err + } + + if latest > version { + upgradeStr = fmt.Sprintf("yes - version %d", latest) + } else { + upgradeStr = "no" + } + + fmt.Printf("Schema Version: %s\nSchema Upgrade Available: %s\nSchema Tables: %s\n", storage.SchemaVersionToString(version), upgradeStr, tablesStr) + + return nil +} diff --git a/internal/configuration/koanf_callbacks.go b/internal/configuration/koanf_callbacks.go index 3823f3911..384b87c88 100644 --- a/internal/configuration/koanf_callbacks.go +++ b/internal/configuration/koanf_callbacks.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" + "github.com/spf13/pflag" + "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/validator" "github.com/authelia/authelia/v4/internal/utils" @@ -48,3 +50,25 @@ func koanfEnvironmentSecretsCallback(keyMap map[string]string, validator *schema return k, v } } + +func koanfCommandLineWithMappingCallback(mapping map[string]string, includeValidKeys, includeUnchangedKeys bool) func(flag *pflag.Flag) (string, interface{}) { + return func(flag *pflag.Flag) (string, interface{}) { + if !includeUnchangedKeys && !flag.Changed { + return "", nil + } + + if actualKey, ok := mapping[flag.Name]; ok { + return actualKey, flag.Value.String() + } + + if includeValidKeys { + formattedKey := strings.ReplaceAll(flag.Name, "-", "_") + + if utils.IsStringInSlice(formattedKey, validator.ValidKeys) { + return formattedKey, flag.Value.String() + } + } + + return "", nil + } +} diff --git a/internal/configuration/provider.go b/internal/configuration/provider.go index a8a94cbfb..5317991cc 100644 --- a/internal/configuration/provider.go +++ b/internal/configuration/provider.go @@ -11,8 +11,17 @@ import ( // Load the configuration given the provided options and sources. func Load(val *schema.StructValidator, sources ...Source) (keys []string, configuration *schema.Configuration, err error) { + configuration = &schema.Configuration{} + + keys, err = LoadAdvanced(val, "", configuration, sources...) + + return keys, configuration, err +} + +// LoadAdvanced is intended to give more flexibility over loading a particular path to a specific interface. +func LoadAdvanced(val *schema.StructValidator, path string, result interface{}, sources ...Source) (keys []string, err error) { if val == nil { - return keys, configuration, errNoValidator + return keys, errNoValidator } ko := koanf.NewWithConf(koanf.Conf{ @@ -22,14 +31,12 @@ func Load(val *schema.StructValidator, sources ...Source) (keys []string, config err = loadSources(ko, val, sources...) if err != nil { - return ko.Keys(), configuration, err + return ko.Keys(), err } - configuration = &schema.Configuration{} + unmarshal(ko, val, path, result) - unmarshal(ko, val, "", configuration) - - return ko.Keys(), configuration, nil + return ko.Keys(), nil } func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o interface{}) { diff --git a/internal/configuration/schema/validator.go b/internal/configuration/schema/validator.go index a37e3eec5..ae16911f7 100644 --- a/internal/configuration/schema/validator.go +++ b/internal/configuration/schema/validator.go @@ -1,12 +1,5 @@ package schema -import ( - "fmt" - "reflect" - - "github.com/Workiva/go-datastructures/queue" -) - // ErrorContainer represents a container where we can add errors and retrieve them. type ErrorContainer interface { Push(err error) @@ -17,100 +10,6 @@ type ErrorContainer interface { Warnings() []error } -// Validator represents the validator interface. -type Validator struct { - errors map[string][]error -} - -// NewValidator create a validator. -func NewValidator() *Validator { - validator := new(Validator) - validator.errors = make(map[string][]error) - - return validator -} - -// QueueItem an item representing a struct field and its path. -type QueueItem struct { - value reflect.Value - path string -} - -func (v *Validator) validateOne(item QueueItem, q *queue.Queue) error { //nolint:unparam - if item.value.Type().Kind() == reflect.Ptr { - if item.value.IsNil() { - return nil - } - - elem := item.value.Elem() - - q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. - value: elem, - path: item.path, - }) - } else if item.value.Kind() == reflect.Struct { - numFields := item.value.Type().NumField() - - validateFn := item.value.Addr().MethodByName("Validate") - - if validateFn.IsValid() { - structValidator := NewStructValidator() - validateFn.Call([]reflect.Value{reflect.ValueOf(structValidator)}) - v.errors[item.path] = structValidator.Errors() - } - - for i := 0; i < numFields; i++ { - field := item.value.Type().Field(i) - value := item.value.Field(i) - - q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. - value: value, - path: item.path + "." + field.Name, - }) - } - } - - return nil -} - -// Validate validate a struct. -func (v *Validator) Validate(s interface{}) error { - q := queue.New(40) - q.Put(QueueItem{value: reflect.ValueOf(s), path: "root"}) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. - - for !q.Empty() { - val, err := q.Get(1) - if err != nil { - return err - } - - item, ok := val[0].(QueueItem) - if !ok { - return fmt.Errorf("Cannot convert item into QueueItem") - } - - v.validateOne(item, q) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. - } - - return nil -} - -// PrintErrors display the errors thrown during validation. -func (v *Validator) PrintErrors() { - for path, errs := range v.errors { - fmt.Printf("Errors at %s:\n", path) - - for _, err := range errs { - fmt.Printf("--> %s\n", err) - } - } -} - -// Errors return the errors thrown during validation. -func (v *Validator) Errors() map[string][]error { - return v.errors -} - // StructValidator is a validator for structs. type StructValidator struct { errors []error diff --git a/internal/configuration/schema/validator_test.go b/internal/configuration/schema/validator_test.go index 0ded4b7b7..f1cedc3ab 100644 --- a/internal/configuration/schema/validator_test.go +++ b/internal/configuration/schema/validator_test.go @@ -49,43 +49,6 @@ func (ts *TestStruct) Validate(validator *schema.StructValidator) { } } -func TestValidator(t *testing.T) { - validator := schema.NewValidator() - - s := TestStruct{ - MustBe10: 5, - NotEmpty: "", - NestedPtr: &TestNestedStruct{}, - } - - err := validator.Validate(&s) - if err != nil { - panic(err) - } - - errs := validator.Errors() - assert.Equal(t, 4, len(errs)) - - assert.Equal(t, 2, len(errs["root"])) - assert.ElementsMatch(t, []error{ - fmt.Errorf("MustBe10 must be 10"), - fmt.Errorf("NotEmpty must not be empty")}, errs["root"]) - - assert.Equal(t, 1, len(errs["root.Nested"])) - assert.ElementsMatch(t, []error{ - fmt.Errorf("MustBe5 must be 5")}, errs["root.Nested"]) - - assert.Equal(t, 1, len(errs["root.Nested2"])) - assert.ElementsMatch(t, []error{ - fmt.Errorf("MustBe5 must be 5")}, errs["root.Nested2"]) - - assert.Equal(t, 1, len(errs["root.NestedPtr"])) - assert.ElementsMatch(t, []error{ - fmt.Errorf("MustBe5 must be 5")}, errs["root.NestedPtr"]) - - assert.Equal(t, "xyz", s.SetDefault) -} - func TestStructValidator(t *testing.T) { validator := schema.NewStructValidator() s := TestStruct{ diff --git a/internal/configuration/sources.go b/internal/configuration/sources.go index e10951517..7506b2f03 100644 --- a/internal/configuration/sources.go +++ b/internal/configuration/sources.go @@ -8,6 +8,8 @@ import ( "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/providers/posflag" + "github.com/spf13/pflag" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/validator" @@ -112,6 +114,37 @@ func (s *SecretsSource) Load(val *schema.StructValidator) (err error) { return s.koanf.Load(env.ProviderWithValue(s.prefix, constDelimiter, koanfEnvironmentSecretsCallback(keyMap, val)), nil) } +// NewCommandLineSourceWithMapping creates a new command line configuration source with a map[string]string which converts +// flag names into other config key names. If includeValidKeys is true we also allow any flag with a name which matches +// the list of valid keys into the koanf.Koanf, otherwise everything not in the map is skipped. Unchanged flags are also +// skipped unless includeUnchangedKeys is set to true. +func NewCommandLineSourceWithMapping(flags *pflag.FlagSet, mapping map[string]string, includeValidKeys, includeUnchangedKeys bool) (source *CommandLineSource) { + return &CommandLineSource{ + koanf: koanf.New(constDelimiter), + flags: flags, + callback: koanfCommandLineWithMappingCallback(mapping, includeValidKeys, includeUnchangedKeys), + } +} + +// Name of the Source. +func (s CommandLineSource) Name() (name string) { + return "command-line" +} + +// Merge the CommandLineSource koanf.Koanf into the provided one. +func (s *CommandLineSource) Merge(ko *koanf.Koanf, val *schema.StructValidator) (err error) { + return ko.Merge(s.koanf) +} + +// Load the Source into the YAMLFileSource koanf.Koanf. +func (s *CommandLineSource) Load(_ *schema.StructValidator) (err error) { + if s.callback != nil { + return s.koanf.Load(posflag.ProviderWithFlag(s.flags, ".", s.koanf, s.callback), nil) + } + + return s.koanf.Load(posflag.Provider(s.flags, ".", s.koanf), nil) +} + // NewDefaultSources returns a slice of Source configured to load from specified YAML files. func NewDefaultSources(filePaths []string, prefix, delimiter string) (sources []Source) { fileSources := NewYAMLFileSources(filePaths) diff --git a/internal/configuration/types.go b/internal/configuration/types.go index c2d029c5e..fe5787be5 100644 --- a/internal/configuration/types.go +++ b/internal/configuration/types.go @@ -2,6 +2,7 @@ package configuration import ( "github.com/knadh/koanf" + "github.com/spf13/pflag" "github.com/authelia/authelia/v4/internal/configuration/schema" ) @@ -32,3 +33,10 @@ type SecretsSource struct { prefix string delimiter string } + +// CommandLineSource loads configuration from the command line flags. +type CommandLineSource struct { + koanf *koanf.Koanf + flags *pflag.FlagSet + callback func(flag *pflag.Flag) (string, interface{}) +} diff --git a/internal/handlers/const.go b/internal/handlers/const.go index 5c4609b7c..f8a1bea28 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -73,6 +73,12 @@ const ( pathOpenIDConnectConsent = "/api/oidc/consent" ) +const ( + totpAlgoSHA1 = "SHA1" + totpAlgoSHA256 = "SHA256" + totpAlgoSHA512 = "SHA512" +) + const ( accept = "accept" reject = "reject" diff --git a/internal/handlers/handler_firstfactor.go b/internal/handlers/handler_firstfactor.go index b131c41dc..a6023b6f4 100644 --- a/internal/handlers/handler_firstfactor.go +++ b/internal/handlers/handler_firstfactor.go @@ -77,7 +77,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware return } - bannedUntil, err := ctx.Providers.Regulator.Regulate(bodyJSON.Username) + bannedUntil, err := ctx.Providers.Regulator.Regulate(ctx, bodyJSON.Username) if err != nil { if err == regulation.ErrUserIsBanned { @@ -95,7 +95,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware if err != nil { ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username) - if err := ctx.Providers.Regulator.Mark(bodyJSON.Username, false); err != nil { + if err := ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, false); err != nil { ctx.Logger.Errorf("Unable to mark authentication: %s", err.Error()) } @@ -107,7 +107,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware if !userPasswordOk { ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username) - if err := ctx.Providers.Regulator.Mark(bodyJSON.Username, false); err != nil { + if err := ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, false); err != nil { ctx.Logger.Errorf("Unable to mark authentication: %s", err.Error()) } @@ -117,7 +117,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware } ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username) - err = ctx.Providers.Regulator.Mark(bodyJSON.Username, true) + err = ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, true) if err != nil { handleAuthenticationUnauthorized(ctx, fmt.Errorf("unable to mark authentication: %s", err.Error()), messageAuthenticationFailed) diff --git a/internal/handlers/handler_firstfactor_test.go b/internal/handlers/handler_firstfactor_test.go index 4902ebbbc..bcfc5f34a 100644 --- a/internal/handlers/handler_firstfactor_test.go +++ b/internal/handlers/handler_firstfactor_test.go @@ -58,7 +58,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() { s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Eq(models.AuthenticationAttempt{ + AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "test", Successful: false, Time: s.mock.Clock.Now(), @@ -83,7 +83,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Eq(models.AuthenticationAttempt{ + AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "test", Successful: false, Time: s.mock.Clock.Now(), @@ -106,7 +106,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() { s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Any()). + AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) s.mock.UserProviderMock. @@ -133,7 +133,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() { s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Any()). + AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(fmt.Errorf("failed")) s.mock.Ctx.Request.SetBodyString(`{ @@ -164,7 +164,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() { s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Any()). + AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) s.mock.Ctx.Request.SetBodyString(`{ @@ -204,7 +204,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() { s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Any()). + AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) s.mock.Ctx.Request.SetBodyString(`{ @@ -248,7 +248,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Any()). + AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) s.mock.Ctx.Request.SetBodyString(`{ @@ -306,7 +306,7 @@ func (s *FirstFactorRedirectionSuite) SetupTest() { s.mock.StorageProviderMock. EXPECT(). - AppendAuthenticationLog(gomock.Any()). + AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) } diff --git a/internal/handlers/handler_register_totp.go b/internal/handlers/handler_register_totp.go index 784ae7378..cdd88305d 100644 --- a/internal/handlers/handler_register_totp.go +++ b/internal/handlers/handler_register_totp.go @@ -3,9 +3,11 @@ package handlers import ( "fmt" + "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/session" ) @@ -37,11 +39,15 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle }) func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { + algorithm := otp.AlgorithmSHA1 + key, err := totp.Generate(totp.GenerateOpts{ Issuer: ctx.Configuration.TOTP.Issuer, AccountName: username, - SecretSize: 32, Period: uint(ctx.Configuration.TOTP.Period), + SecretSize: 32, + Digits: otp.Digits(6), + Algorithm: algorithm, }) if err != nil { @@ -49,7 +55,15 @@ func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username strin return } - err = ctx.Providers.StorageProvider.SaveTOTPSecret(username, key.Secret()) + config := models.TOTPConfiguration{ + Username: username, + Algorithm: otpAlgoToString(algorithm), + Digits: 6, + Secret: key.Secret(), + Period: key.Period(), + } + + err = ctx.Providers.StorageProvider.SaveTOTPConfiguration(ctx, config) if err != nil { ctx.Error(fmt.Errorf("unable to save TOTP secret in DB: %s", err), messageUnableToRegisterOneTimePassword) return diff --git a/internal/handlers/handler_register_u2f_step1_test.go b/internal/handlers/handler_register_u2f_step1_test.go index b779d31c7..1706eada5 100644 --- a/internal/handlers/handler_register_u2f_step1_test.go +++ b/internal/handlers/handler_register_u2f_step1_test.go @@ -57,11 +57,11 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) s.mock.StorageProviderMock.EXPECT(). - RemoveIdentityVerificationToken(gomock.Eq(token)). + RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(nil) SecondFactorU2FIdentityFinish(s.mock.Ctx) @@ -77,11 +77,11 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) s.mock.StorageProviderMock.EXPECT(). - RemoveIdentityVerificationToken(gomock.Eq(token)). + RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(nil) SecondFactorU2FIdentityFinish(s.mock.Ctx) diff --git a/internal/handlers/handler_register_u2f_step2.go b/internal/handlers/handler_register_u2f_step2.go index d36090a76..dd2b70ff9 100644 --- a/internal/handlers/handler_register_u2f_step2.go +++ b/internal/handlers/handler_register_u2f_step2.go @@ -7,6 +7,7 @@ import ( "github.com/tstranex/u2f" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/models" ) // SecondFactorU2FRegister handler validating the client has successfully validated the challenge @@ -45,7 +46,12 @@ func SecondFactorU2FRegister(ctx *middlewares.AutheliaCtx) { ctx.Logger.Debugf("Register U2F device for user %s", userSession.Username) publicKey := elliptic.Marshal(elliptic.P256(), registration.PubKey.X, registration.PubKey.Y) - err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, registration.KeyHandle, publicKey) + + err = ctx.Providers.StorageProvider.SaveU2FDevice(ctx, models.U2FDevice{ + Username: userSession.Username, + KeyHandle: registration.KeyHandle, + PublicKey: publicKey}, + ) if err != nil { ctx.Error(fmt.Errorf("unable to register U2F device for user %s: %v", userSession.Username, err), messageUnableToRegisterSecurityKey) diff --git a/internal/handlers/handler_sign_totp.go b/internal/handlers/handler_sign_totp.go index daf69e61a..71c66879a 100644 --- a/internal/handlers/handler_sign_totp.go +++ b/internal/handlers/handler_sign_totp.go @@ -19,13 +19,13 @@ func SecondFactorTOTPPost(totpVerifier TOTPVerifier) middlewares.RequestHandler userSession := ctx.GetSession() - secret, err := ctx.Providers.StorageProvider.LoadTOTPSecret(userSession.Username) + config, err := ctx.Providers.StorageProvider.LoadTOTPConfiguration(ctx, userSession.Username) if err != nil { handleAuthenticationUnauthorized(ctx, fmt.Errorf("unable to load TOTP secret: %s", err), messageMFAValidationFailed) return } - isValid, err := totpVerifier.Verify(requestBody.Token, secret) + isValid, err := totpVerifier.Verify(config, requestBody.Token) if err != nil { handleAuthenticationUnauthorized(ctx, fmt.Errorf("error occurred during OTP validation for user %s: %s", userSession.Username, err), messageMFAValidationFailed) return diff --git a/internal/handlers/handler_sign_totp_test.go b/internal/handlers/handler_sign_totp_test.go index 2220fea59..426c314df 100644 --- a/internal/handlers/handler_sign_totp_test.go +++ b/internal/handlers/handler_sign_totp_test.go @@ -11,6 +11,7 @@ import ( "github.com/tstranex/u2f" "github.com/authelia/authelia/v4/internal/mocks" + "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/session" ) @@ -37,12 +38,14 @@ func (s *HandlerSignTOTPSuite) TearDownTest() { func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() { verifier := NewMockTOTPVerifier(s.mock.Ctrl) + config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"} + s.mock.StorageProviderMock.EXPECT(). - LoadTOTPSecret(gomock.Any()). - Return("secret", nil) + LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). + Return(&config, nil) verifier.EXPECT(). - Verify(gomock.Eq("abc"), gomock.Eq("secret")). + Verify(gomock.Eq(&config), gomock.Eq("abc")). Return(true, nil) s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL @@ -62,12 +65,14 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() { func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() { verifier := NewMockTOTPVerifier(s.mock.Ctrl) + config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"} + s.mock.StorageProviderMock.EXPECT(). - LoadTOTPSecret(gomock.Any()). - Return("secret", nil) + LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). + Return(&config, nil) verifier.EXPECT(). - Verify(gomock.Eq("abc"), gomock.Eq("secret")). + Verify(gomock.Eq(&config), gomock.Eq("abc")). Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ @@ -83,12 +88,14 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() { func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() { verifier := NewMockTOTPVerifier(s.mock.Ctrl) + config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"} + s.mock.StorageProviderMock.EXPECT(). - LoadTOTPSecret(gomock.Any()). - Return("secret", nil) + LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). + Return(&config, nil) verifier.EXPECT(). - Verify(gomock.Eq("abc"), gomock.Eq("secret")). + Verify(gomock.Eq(&config), gomock.Eq("abc")). Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ @@ -108,11 +115,11 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() { verifier := NewMockTOTPVerifier(s.mock.Ctrl) s.mock.StorageProviderMock.EXPECT(). - LoadTOTPSecret(gomock.Any()). - Return("secret", nil) + LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). + Return(&models.TOTPConfiguration{Secret: "secret"}, nil) verifier.EXPECT(). - Verify(gomock.Eq("abc"), gomock.Eq("secret")). + Verify(gomock.Eq(&models.TOTPConfiguration{Secret: "secret"}), gomock.Eq("abc")). Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ @@ -129,12 +136,14 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() { func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFixation() { verifier := NewMockTOTPVerifier(s.mock.Ctrl) + config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"} + s.mock.StorageProviderMock.EXPECT(). - LoadTOTPSecret(gomock.Any()). - Return("secret", nil) + LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). + Return(&config, nil) verifier.EXPECT(). - Verify(gomock.Eq("abc"), gomock.Eq("secret")). + Verify(gomock.Eq(&config), gomock.Eq("abc")). Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ diff --git a/internal/handlers/handler_sign_u2f_step1.go b/internal/handlers/handler_sign_u2f_step1.go index a96b7ec83..613d2372c 100644 --- a/internal/handlers/handler_sign_u2f_step1.go +++ b/internal/handlers/handler_sign_u2f_step1.go @@ -34,7 +34,7 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) { } userSession := ctx.GetSession() - keyHandleBytes, publicKeyBytes, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username) + device, err := ctx.Providers.StorageProvider.LoadU2FDevice(ctx, userSession.Username) if err != nil { if err == storage.ErrNoU2FDeviceHandle { @@ -48,16 +48,16 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) { } var registration u2f.Registration - registration.KeyHandle = keyHandleBytes - x, y := elliptic.Unmarshal(elliptic.P256(), publicKeyBytes) + registration.KeyHandle = device.KeyHandle + x, y := elliptic.Unmarshal(elliptic.P256(), device.PublicKey) registration.PubKey.Curve = elliptic.P256() registration.PubKey.X = x registration.PubKey.Y = y // Save the challenge and registration for use in next request userSession.U2FRegistration = &session.U2FRegistration{ - KeyHandle: keyHandleBytes, - PublicKey: publicKeyBytes, + KeyHandle: device.KeyHandle, + PublicKey: device.PublicKey, } userSession.U2FChallenge = challenge err = ctx.SaveSession(userSession) diff --git a/internal/handlers/handler_user_info.go b/internal/handlers/handler_user_info.go index f45f3300e..f184a9656 100644 --- a/internal/handlers/handler_user_info.go +++ b/internal/handlers/handler_user_info.go @@ -3,97 +3,25 @@ package handlers import ( "fmt" "strings" - "sync" - - "github.com/sirupsen/logrus" "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/middlewares" - "github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/utils" ) -func loadInfo(username string, storageProvider storage.Provider, userInfo *UserInfo, logger *logrus.Entry) []error { - var wg sync.WaitGroup - - wg.Add(3) - - errors := make([]error, 0) - - go func() { - defer wg.Done() - - method, err := storageProvider.LoadPreferred2FAMethod(username) - if err != nil { - errors = append(errors, err) - logger.Error(err) - - return - } - - if method == "" { - userInfo.Method = authentication.PossibleMethods[0] - } else { - userInfo.Method = method - } - }() - - go func() { - defer wg.Done() - - _, _, err := storageProvider.LoadU2FDeviceHandle(username) - if err != nil { - if err == storage.ErrNoU2FDeviceHandle { - return - } - - errors = append(errors, err) - logger.Error(err) - - return - } - - userInfo.HasU2F = true - }() - - go func() { - defer wg.Done() - - _, err := storageProvider.LoadTOTPSecret(username) - if err != nil { - if err == storage.ErrNoTOTPSecret { - return - } - - errors = append(errors, err) - logger.Error(err) - - return - } - - userInfo.HasTOTP = true - }() - - wg.Wait() - - return errors -} - // UserInfoGet get the info related to the user identified by the session. func UserInfoGet(ctx *middlewares.AutheliaCtx) { userSession := ctx.GetSession() - userInfo := UserInfo{} - errors := loadInfo(userSession.Username, ctx.Providers.StorageProvider, &userInfo, ctx.Logger) - - if len(errors) > 0 { - ctx.Error(fmt.Errorf("unable to load user information"), messageOperationFailed) + userInfo, err := ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username) + if err != nil { + ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed) return } userInfo.DisplayName = userSession.DisplayName - err := ctx.SetJSONBody(userInfo) + err = ctx.SetJSONBody(userInfo) if err != nil { ctx.Logger.Errorf("Unable to set user info response in body: %s", err) } @@ -121,7 +49,7 @@ func MethodPreferencePost(ctx *middlewares.AutheliaCtx) { userSession := ctx.GetSession() ctx.Logger.Debugf("Save new preferred 2FA method of user %s to %s", userSession.Username, bodyJSON.Method) - err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(userSession.Username, bodyJSON.Method) + err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, bodyJSON.Method) if err != nil { ctx.Error(fmt.Errorf("unable to save new preferred 2FA method: %s", err), messageOperationFailed) diff --git a/internal/handlers/handler_user_info_test.go b/internal/handlers/handler_user_info_test.go index db5db2e52..f64d36d34 100644 --- a/internal/handlers/handler_user_info_test.go +++ b/internal/handlers/handler_user_info_test.go @@ -1,6 +1,8 @@ package handlers import ( + "database/sql" + "errors" "fmt" "testing" @@ -11,7 +13,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/authelia/authelia/v4/internal/mocks" - "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/models" ) type FetchSuite struct { @@ -33,62 +35,59 @@ func (s *FetchSuite) TearDownTest() { s.mock.Close() } -func setPreferencesExpectations(preferences UserInfo, provider *storage.MockProvider) { - provider. - EXPECT(). - LoadPreferred2FAMethod(gomock.Eq("john")). - Return(preferences.Method, nil) - - if preferences.HasU2F { - u2fData := []byte("abc") - provider. - EXPECT(). - LoadU2FDeviceHandle(gomock.Eq("john")). - Return(u2fData, u2fData, nil) - } else { - provider. - EXPECT(). - LoadU2FDeviceHandle(gomock.Eq("john")). - Return(nil, nil, storage.ErrNoU2FDeviceHandle) - } - - if preferences.HasTOTP { - totpSecret := "secret" - provider. - EXPECT(). - LoadTOTPSecret(gomock.Eq("john")). - Return(totpSecret, nil) - } else { - provider. - EXPECT(). - LoadTOTPSecret(gomock.Eq("john")). - Return("", storage.ErrNoTOTPSecret) - } +type expectedResponse struct { + db models.UserInfo + api *models.UserInfo + err error } func TestMethodSetToU2F(t *testing.T) { - table := []UserInfo{ + expectedResponses := []expectedResponse{ { - Method: "totp", + db: models.UserInfo{ + Method: "totp", + }, + err: nil, }, { - Method: "u2f", - HasU2F: true, - HasTOTP: true, + db: models.UserInfo{ + Method: "u2f", + HasU2F: true, + HasTOTP: true, + }, + err: nil, }, { - Method: "u2f", - HasU2F: true, - HasTOTP: false, + db: models.UserInfo{ + Method: "u2f", + HasU2F: true, + HasTOTP: false, + }, + err: nil, }, { - Method: "mobile_push", - HasU2F: false, - HasTOTP: false, + db: models.UserInfo{ + Method: "mobile_push", + HasU2F: false, + HasTOTP: false, + }, + err: nil, + }, + { + db: models.UserInfo{}, + err: sql.ErrNoRows, + }, + { + db: models.UserInfo{}, + err: errors.New("invalid thing"), }, } - for _, expectedPreferences := range table { + for _, resp := range expectedResponses { + if resp.api == nil { + resp.api = &resp.db + } + mock := mocks.NewMockAutheliaCtx(t) // Set the initial user session. userSession := mock.Ctx.GetSession() @@ -97,64 +96,57 @@ func TestMethodSetToU2F(t *testing.T) { err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) - setPreferencesExpectations(expectedPreferences, mock.StorageProviderMock) + mock.StorageProviderMock. + EXPECT(). + LoadUserInfo(mock.Ctx, gomock.Eq("john")). + Return(resp.db, resp.err) + UserInfoGet(mock.Ctx) - actualPreferences := UserInfo{} - mock.GetResponseData(t, &actualPreferences) + if resp.err == nil { + t.Run("expected status code", func(t *testing.T) { + assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) + }) - t.Run("expected method", func(t *testing.T) { - assert.Equal(t, expectedPreferences.Method, actualPreferences.Method) - }) + actualPreferences := models.UserInfo{} - t.Run("registered u2f", func(t *testing.T) { - assert.Equal(t, expectedPreferences.HasU2F, actualPreferences.HasU2F) - }) + mock.GetResponseData(t, &actualPreferences) + + t.Run("expected method", func(t *testing.T) { + assert.Equal(t, resp.api.Method, actualPreferences.Method) + }) + + t.Run("registered u2f", func(t *testing.T) { + assert.Equal(t, resp.api.HasU2F, actualPreferences.HasU2F) + }) + + t.Run("registered totp", func(t *testing.T) { + assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP) + }) + } else { + t.Run("expected status code", func(t *testing.T) { + assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) + }) + + errResponse := mock.GetResponseError(t) + + assert.Equal(t, "KO", errResponse.Status) + assert.Equal(t, "Operation failed.", errResponse.Message) + } - t.Run("registered totp", func(t *testing.T) { - assert.Equal(t, expectedPreferences.HasTOTP, actualPreferences.HasTOTP) - }) mock.Close() } } -func (s *FetchSuite) TestShouldGetDefaultPreferenceIfNotInDB() { - s.mock.StorageProviderMock. - EXPECT(). - LoadPreferred2FAMethod(gomock.Eq("john")). - Return("", nil) - - s.mock.StorageProviderMock. - EXPECT(). - LoadU2FDeviceHandle(gomock.Eq("john")). - Return(nil, nil, storage.ErrNoU2FDeviceHandle) - - s.mock.StorageProviderMock. - EXPECT(). - LoadTOTPSecret(gomock.Eq("john")). - Return("", storage.ErrNoTOTPSecret) - - UserInfoGet(s.mock.Ctx) - s.mock.Assert200OK(s.T(), UserInfo{Method: "totp"}) -} - func (s *FetchSuite) TestShouldReturnError500WhenStorageFailsToLoad() { s.mock.StorageProviderMock.EXPECT(). - LoadPreferred2FAMethod(gomock.Eq("john")). - Return("", fmt.Errorf("Failure")) - - s.mock.StorageProviderMock. - EXPECT(). - LoadU2FDeviceHandle(gomock.Eq("john")) - - s.mock.StorageProviderMock. - EXPECT(). - LoadTOTPSecret(gomock.Eq("john")) + LoadUserInfo(s.mock.Ctx, gomock.Eq("john")). + Return(models.UserInfo{}, fmt.Errorf("failure")) UserInfoGet(s.mock.Ctx) s.mock.Assert200KO(s.T(), "Operation failed.") - assert.Equal(s.T(), "unable to load user information", s.mock.Hook.LastEntry().Message) + assert.Equal(s.T(), "unable to load user information: failure", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), logrus.ErrorLevel, s.mock.Hook.LastEntry().Level) } @@ -220,7 +212,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenBadMethodProvided() { func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() { s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}")) s.mock.StorageProviderMock.EXPECT(). - SavePreferred2FAMethod(gomock.Eq("john"), gomock.Eq("u2f")). + SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")). Return(fmt.Errorf("Failure")) MethodPreferencePost(s.mock.Ctx) @@ -233,7 +225,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() { func (s *SaveSuite) TestShouldReturn200WhenMethodIsSuccessfullySaved() { s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}")) s.mock.StorageProviderMock.EXPECT(). - SavePreferred2FAMethod(gomock.Eq("john"), gomock.Eq("u2f")). + SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")). Return(nil) MethodPreferencePost(s.mock.Ctx) diff --git a/internal/handlers/totp.go b/internal/handlers/totp.go index fe4e7c2a3..510292d68 100644 --- a/internal/handlers/totp.go +++ b/internal/handlers/totp.go @@ -1,15 +1,18 @@ package handlers import ( + "errors" "time" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" + + "github.com/authelia/authelia/v4/internal/models" ) // TOTPVerifier is the interface for verifying TOTPs. type TOTPVerifier interface { - Verify(token, secret string) (bool, error) + Verify(config *models.TOTPConfiguration, token string) (bool, error) } // TOTPVerifierImpl the production implementation for TOTP verification. @@ -19,13 +22,43 @@ type TOTPVerifierImpl struct { } // Verify verifies TOTPs. -func (tv *TOTPVerifierImpl) Verify(token, secret string) (bool, error) { - opts := totp.ValidateOpts{ - Period: tv.Period, - Skew: tv.Skew, - Digits: otp.DigitsSix, - Algorithm: otp.AlgorithmSHA1, +func (tv *TOTPVerifierImpl) Verify(config *models.TOTPConfiguration, token string) (bool, error) { + if config == nil { + return false, errors.New("config not provided") } - return totp.ValidateCustom(token, secret, time.Now().UTC(), opts) + opts := totp.ValidateOpts{ + Period: uint(config.Period), + Skew: tv.Skew, + Digits: otp.Digits(config.Digits), + Algorithm: otpStringToAlgo(config.Algorithm), + } + + return totp.ValidateCustom(token, config.Secret, time.Now().UTC(), opts) +} + +func otpAlgoToString(algorithm otp.Algorithm) (out string) { + switch algorithm { + case otp.AlgorithmSHA1: + return totpAlgoSHA1 + case otp.AlgorithmSHA256: + return totpAlgoSHA256 + case otp.AlgorithmSHA512: + return totpAlgoSHA512 + default: + return "" + } +} + +func otpStringToAlgo(in string) (algorithm otp.Algorithm) { + switch in { + case totpAlgoSHA1: + return otp.AlgorithmSHA1 + case totpAlgoSHA256: + return otp.AlgorithmSHA256 + case totpAlgoSHA512: + return otp.AlgorithmSHA512 + default: + return otp.AlgorithmSHA1 + } } diff --git a/internal/handlers/totp_mock.go b/internal/handlers/totp_mock.go index 1dad3fedf..0da2a4e84 100644 --- a/internal/handlers/totp_mock.go +++ b/internal/handlers/totp_mock.go @@ -5,45 +5,47 @@ package handlers import ( - reflect "reflect" + "reflect" - gomock "github.com/golang/mock/gomock" + "github.com/golang/mock/gomock" + + "github.com/authelia/authelia/v4/internal/models" ) -// MockTOTPVerifier is a mock of TOTPVerifier interface +// MockTOTPVerifier is a mock of TOTPVerifier interface. type MockTOTPVerifier struct { ctrl *gomock.Controller recorder *MockTOTPVerifierMockRecorder } -// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier +// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier. type MockTOTPVerifierMockRecorder struct { mock *MockTOTPVerifier } -// NewMockTOTPVerifier creates a new mock instance +// NewMockTOTPVerifier creates a new mock instance. func NewMockTOTPVerifier(ctrl *gomock.Controller) *MockTOTPVerifier { mock := &MockTOTPVerifier{ctrl: ctrl} mock.recorder = &MockTOTPVerifierMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockTOTPVerifier) EXPECT() *MockTOTPVerifierMockRecorder { return m.recorder } -// Verify mocks base method -func (m *MockTOTPVerifier) Verify(token, secret string) (bool, error) { +// Verify mocks base method. +func (m *MockTOTPVerifier) Verify(arg0 *models.TOTPConfiguration, arg1 string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Verify", token, secret) + ret := m.ctrl.Call(m, "Verify", arg0, arg1) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// Verify indicates an expected call of Verify -func (mr *MockTOTPVerifierMockRecorder) Verify(token, secret interface{}) *gomock.Call { +// Verify indicates an expected call of Verify. +func (mr *MockTOTPVerifierMockRecorder) Verify(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), token, secret) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), arg0, arg1) } diff --git a/internal/handlers/types.go b/internal/handlers/types.go index 2038108a7..4fa4886e9 100644 --- a/internal/handlers/types.go +++ b/internal/handlers/types.go @@ -11,21 +11,6 @@ type MethodList = []string type authorizationMatching int -// UserInfo is the model of user info and second factor preferences. -type UserInfo struct { - // The users display name. - DisplayName string `json:"display_name"` - - // The preferred 2FA method. - Method string `json:"method" valid:"required"` - - // True if a security key has been registered. - HasU2F bool `json:"has_u2f" valid:"required"` - - // True if a TOTP device has been registered. - HasTOTP bool `json:"has_totp" valid:"required"` -} - // signTOTPRequestBody model of the request body received by TOTP authentication endpoint. type signTOTPRequestBody struct { Token string `json:"token" valid:"required"` diff --git a/internal/middlewares/identity_verification.go b/internal/middlewares/identity_verification.go index 47fa6ecaf..d4f33571e 100644 --- a/internal/middlewares/identity_verification.go +++ b/internal/middlewares/identity_verification.go @@ -8,6 +8,7 @@ import ( "github.com/golang-jwt/jwt/v4" + "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/templates" ) @@ -47,7 +48,9 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle return } - err = ctx.Providers.StorageProvider.SaveIdentityVerificationToken(ss) + err = ctx.Providers.StorageProvider.SaveIdentityVerification(ctx, models.IdentityVerification{ + Token: ss, + }) if err != nil { ctx.Error(err, messageOperationFailed) return @@ -128,7 +131,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c return } - found, err := ctx.Providers.StorageProvider.FindIdentityVerificationToken(finishBody.Token) + found, err := ctx.Providers.StorageProvider.FindIdentityVerification(ctx, finishBody.Token) if err != nil { ctx.Error(err, messageOperationFailed) @@ -185,7 +188,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c } // TODO(c.michaud): find a way to garbage collect unused tokens. - err = ctx.Providers.StorageProvider.RemoveIdentityVerificationToken(finishBody.Token) + err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, finishBody.Token) if err != nil { ctx.Error(err, messageOperationFailed) return diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go index 395657f42..29423ae61 100644 --- a/internal/middlewares/identity_verification_test.go +++ b/internal/middlewares/identity_verification_test.go @@ -55,7 +55,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) { mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.StorageProviderMock.EXPECT(). - SaveIdentityVerificationToken(gomock.Any()). + SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(fmt.Errorf("cannot save")) args := newArgs(defaultRetriever) @@ -74,7 +74,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) { mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.StorageProviderMock.EXPECT(). - SaveIdentityVerificationToken(gomock.Any()). + SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) mock.NotifierMock.EXPECT(). @@ -96,7 +96,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) { mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.StorageProviderMock.EXPECT(). - SaveIdentityVerificationToken(gomock.Any()). + SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) args := newArgs(defaultRetriever) @@ -114,7 +114,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) { mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.StorageProviderMock.EXPECT(). - SaveIdentityVerificationToken(gomock.Any()). + SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) args := newArgs(defaultRetriever) @@ -132,7 +132,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) { mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.StorageProviderMock.EXPECT(). - SaveIdentityVerificationToken(gomock.Any()). + SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) mock.NotifierMock.EXPECT(). @@ -209,7 +209,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB( s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}") s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq("abc")). + FindIdentityVerification(s.mock.Ctx, gomock.Eq("abc")). Return(false, nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) @@ -222,7 +222,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() { s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}") s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq("abc")). + FindIdentityVerification(s.mock.Ctx, gomock.Eq("abc")). Return(true, nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) @@ -238,7 +238,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() { s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) @@ -253,7 +253,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() { s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) @@ -268,7 +268,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() { s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) args := newFinishArgs() @@ -285,11 +285,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) s.mock.StorageProviderMock.EXPECT(). - RemoveIdentityVerificationToken(gomock.Eq(token)). + RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(fmt.Errorf("cannot remove")) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) @@ -304,11 +304,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete( s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerificationToken(gomock.Eq(token)). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(true, nil) s.mock.StorageProviderMock.EXPECT(). - RemoveIdentityVerificationToken(gomock.Eq(token)). + RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)). Return(nil) middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index b797f4f8a..82e3e85c1 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -28,11 +28,6 @@ type AutheliaCtx struct { Clock utils.Clock } -// ProviderWithStartupCheck represents a provider that has a startup check. -type ProviderWithStartupCheck interface { - StartupCheck(logger *logrus.Logger) (err error) -} - // Providers contain all provider provided to Authelia. type Providers struct { Authorizer *authorization.Authorizer diff --git a/internal/mocks/mock_authelia_ctx.go b/internal/mocks/mock_authelia_ctx.go index 40a4bb80c..e56265154 100644 --- a/internal/mocks/mock_authelia_ctx.go +++ b/internal/mocks/mock_authelia_ctx.go @@ -183,3 +183,11 @@ func (m *MockAutheliaCtx) GetResponseData(t *testing.T, data interface{}) { err := json.Unmarshal(m.Ctx.Response.Body(), &okResponse) require.NoError(t, err) } + +// GetResponseError retrieves an error response from the service. +func (m *MockAutheliaCtx) GetResponseError(t *testing.T) (errResponse middlewares.ErrorResponse) { + err := json.Unmarshal(m.Ctx.Response.Body(), &errResponse) + require.NoError(t, err) + + return errResponse +} diff --git a/internal/mocks/mock_notifier.go b/internal/mocks/mock_notifier.go index aebba9a08..05ca829fd 100644 --- a/internal/mocks/mock_notifier.go +++ b/internal/mocks/mock_notifier.go @@ -8,7 +8,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - "github.com/sirupsen/logrus" ) // MockNotifier is a mock of Notifier interface. @@ -49,15 +48,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go } // StartupCheck mocks base method. -func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error { +func (m *MockNotifier) StartupCheck() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartupCheck", arg0) + ret := m.ctrl.Call(m, "StartupCheck") ret0, _ := ret[0].(error) return ret0 } // StartupCheck indicates an expected call of StartupCheck. -func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call { +func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck)) } diff --git a/internal/mocks/mock_user_provider.go b/internal/mocks/mock_user_provider.go index c05d223e8..25d0e2aec 100644 --- a/internal/mocks/mock_user_provider.go +++ b/internal/mocks/mock_user_provider.go @@ -5,12 +5,11 @@ package mocks import ( - "reflect" + reflect "reflect" - "github.com/golang/mock/gomock" - "github.com/sirupsen/logrus" + gomock "github.com/golang/mock/gomock" - "github.com/authelia/authelia/v4/internal/authentication" + authentication "github.com/authelia/authelia/v4/internal/authentication" ) // MockUserProvider is a mock of UserProvider interface. @@ -66,7 +65,21 @@ func (mr *MockUserProviderMockRecorder) GetDetails(arg0 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDetails", reflect.TypeOf((*MockUserProvider)(nil).GetDetails), arg0) } -// UpdatePassword mocks base method +// StartupCheck mocks base method. +func (m *MockUserProvider) StartupCheck() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartupCheck") + ret0, _ := ret[0].(error) + return ret0 +} + +// StartupCheck indicates an expected call of StartupCheck. +func (mr *MockUserProviderMockRecorder) StartupCheck() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck)) +} + +// UpdatePassword mocks base method. func (m *MockUserProvider) UpdatePassword(arg0, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1) @@ -79,17 +92,3 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) * mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1) } - -// StartupCheck mocks base method. -func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "StartupCheck", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// StartupCheck indicates an expected call of StartupCheck. -func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0) -} diff --git a/internal/models/model_authentication_attempt.go b/internal/models/model_authentication_attempt.go new file mode 100644 index 000000000..6720b524c --- /dev/null +++ b/internal/models/model_authentication_attempt.go @@ -0,0 +1,17 @@ +package models + +import ( + "time" +) + +// AuthenticationAttempt represents an authentication attempt row in the database. +type AuthenticationAttempt struct { + ID int `db:"id"` + Time time.Time `db:"time"` + Successful bool `db:"successful"` + Username string `db:"username"` + Type string `db:"auth_type"` + RemoteIP IPAddress `db:"remote_ip"` + RequestURI string `db:"request_uri"` + RequestMethod string `db:"request_method"` +} diff --git a/internal/models/model_identity_verification.go b/internal/models/model_identity_verification.go new file mode 100644 index 000000000..873ba92a1 --- /dev/null +++ b/internal/models/model_identity_verification.go @@ -0,0 +1,12 @@ +package models + +import ( + "time" +) + +// IdentityVerification represents an identity verification row in the database. +type IdentityVerification struct { + ID int `db:"id"` + Created time.Time `db:"created"` + Token string `db:"token"` +} diff --git a/internal/models/model_migration.go b/internal/models/model_migration.go new file mode 100644 index 000000000..54fafff3d --- /dev/null +++ b/internal/models/model_migration.go @@ -0,0 +1,14 @@ +package models + +import ( + "time" +) + +// Migration represents a migration row in the database. +type Migration struct { + ID int `db:"id"` + Applied time.Time `db:"applied"` + Before int `db:"version_before"` + After int `db:"version_after"` + Version string `db:"application_version"` +} diff --git a/internal/models/model_totp_configuration.go b/internal/models/model_totp_configuration.go new file mode 100644 index 000000000..b274689dc --- /dev/null +++ b/internal/models/model_totp_configuration.go @@ -0,0 +1,11 @@ +package models + +// TOTPConfiguration represents a users TOTP configuration row in the database. +type TOTPConfiguration struct { + ID int `db:"id"` + Username string `db:"username"` + Algorithm string `db:"algorithm"` + Digits int `db:"digits"` + Period uint64 `db:"totp_period"` + Secret string `db:"secret"` +} diff --git a/internal/models/model_u2f_device.go b/internal/models/model_u2f_device.go new file mode 100644 index 000000000..37e36b2c9 --- /dev/null +++ b/internal/models/model_u2f_device.go @@ -0,0 +1,10 @@ +package models + +// U2FDevice represents a users U2F device row in the database. +type U2FDevice struct { + ID int `db:"id"` + Username string `db:"username"` + Description string `db:"description"` + KeyHandle []byte `db:"key_handle"` + PublicKey []byte `db:"public_key"` +} diff --git a/internal/models/model_userinfo.go b/internal/models/model_userinfo.go new file mode 100644 index 000000000..da0b1f9fb --- /dev/null +++ b/internal/models/model_userinfo.go @@ -0,0 +1,16 @@ +package models + +// UserInfo represents the user information required by the web UI. +type UserInfo struct { + // The users display name. + DisplayName string `db:"-" json:"display_name"` + + // The preferred 2FA method. + Method string `db:"second_factor_method" json:"method" valid:"required"` + + // True if a security key has been registered. + HasU2F bool `db:"has_u2f" json:"has_u2f" valid:"required"` + + // True if a TOTP device has been registered. + HasTOTP bool `db:"has_totp" json:"has_totp" valid:"required"` +} diff --git a/internal/models/type_ipaddress.go b/internal/models/type_ipaddress.go new file mode 100644 index 000000000..8078748bd --- /dev/null +++ b/internal/models/type_ipaddress.go @@ -0,0 +1,42 @@ +package models + +import ( + "database/sql/driver" + "fmt" + "net" +) + +// IPAddress is a type specific for storage of a net.IP in the database. +type IPAddress struct { + *net.IP +} + +// Value is the IPAddress implementation of the databases/sql driver.Valuer. +func (ip IPAddress) Value() (value driver.Value, err error) { + if ip.IP == nil { + return driver.Value(nil), nil + } + + return driver.Value(ip.IP.String()), nil +} + +// Scan is the IPAddress implementation of the sql.Scanner. +func (ip *IPAddress) Scan(src interface{}) (err error) { + if src == nil { + ip.IP = nil + return nil + } + + var value string + + switch v := src.(type) { + case string: + value = v + default: + return fmt.Errorf("invalid type %T for IPAddress %v", src, src) + } + + *ip.IP = net.ParseIP(value) + + return nil +} diff --git a/internal/models/type_startup_check.go b/internal/models/type_startup_check.go new file mode 100644 index 000000000..76ca09ff4 --- /dev/null +++ b/internal/models/type_startup_check.go @@ -0,0 +1,6 @@ +package models + +// StartupCheck represents a provider that has a startup check. +type StartupCheck interface { + StartupCheck() (err error) +} diff --git a/internal/models/types.go b/internal/models/types.go deleted file mode 100644 index 7a3e03739..000000000 --- a/internal/models/types.go +++ /dev/null @@ -1,13 +0,0 @@ -package models - -import "time" - -// AuthenticationAttempt represent an authentication attempt. -type AuthenticationAttempt struct { - // The user who tried to authenticate. - Username string - // Successful true if the attempt was successful. - Successful bool - // The time of the attempt. - Time time.Time -} diff --git a/internal/notification/file_notifier.go b/internal/notification/file_notifier.go index f0de2174e..db287a6de 100644 --- a/internal/notification/file_notifier.go +++ b/internal/notification/file_notifier.go @@ -7,8 +7,6 @@ import ( "path/filepath" "time" - "github.com/sirupsen/logrus" - "github.com/authelia/authelia/v4/internal/configuration/schema" ) @@ -25,7 +23,7 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File } // StartupCheck implements the startup check provider interface. -func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) { +func (n *FileNotifier) StartupCheck() (err error) { dir := filepath.Dir(n.path) if _, err := os.Stat(dir); err != nil { if os.IsNotExist(err) { diff --git a/internal/notification/notifier.go b/internal/notification/notifier.go index 74210d494..925dd1415 100644 --- a/internal/notification/notifier.go +++ b/internal/notification/notifier.go @@ -1,11 +1,12 @@ package notification import ( - "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/models" ) // Notifier interface for sending the identity verification link. type Notifier interface { + models.StartupCheck + Send(recipient, subject, body, htmlBody string) (err error) - StartupCheck(logger *logrus.Logger) (err error) } diff --git a/internal/notification/smtp_notifier.go b/internal/notification/smtp_notifier.go index 28bb47849..332712704 100644 --- a/internal/notification/smtp_notifier.go +++ b/internal/notification/smtp_notifier.go @@ -10,8 +10,6 @@ import ( "strings" "time" - "github.com/sirupsen/logrus" - "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/utils" @@ -223,7 +221,7 @@ func (n *SMTPNotifier) cleanup() { } // StartupCheck implements the startup check provider interface. -func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) { +func (n *SMTPNotifier) StartupCheck() (err error) { if err := n.dial(); err != nil { return err } diff --git a/internal/ntp/ntp.go b/internal/ntp/ntp.go index e30c4ca6f..36b5370be 100644 --- a/internal/ntp/ntp.go +++ b/internal/ntp/ntp.go @@ -6,22 +6,24 @@ import ( "net" "time" - "github.com/sirupsen/logrus" - "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/utils" ) // NewProvider instantiate a ntp provider given a configuration. func NewProvider(config *schema.NTPConfiguration) *Provider { - return &Provider{config} + return &Provider{ + config: config, + log: logging.Logger(), + } } // StartupCheck implements the startup check provider interface. -func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) { +func (p *Provider) StartupCheck() (err error) { conn, err := net.Dial("udp", p.config.Address) if err != nil { - logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err) + p.log.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err) return nil } @@ -29,7 +31,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) { defer conn.Close() if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err) + p.log.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err) return nil } @@ -42,7 +44,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) { req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)} if err := binary.Write(conn, binary.BigEndian, req); err != nil { - logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err) + p.log.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err) return nil } @@ -52,7 +54,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) { resp := &ntpPacket{} if err := binary.Read(conn, binary.BigEndian, resp); err != nil { - logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err) + p.log.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err) return nil } diff --git a/internal/ntp/ntp_test.go b/internal/ntp/ntp_test.go index 95e230b8e..756568bdd 100644 --- a/internal/ntp/ntp_test.go +++ b/internal/ntp/ntp_test.go @@ -7,7 +7,6 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/validator" - "github.com/authelia/authelia/v4/internal/logging" ) func TestShouldCheckNTP(t *testing.T) { @@ -22,5 +21,5 @@ func TestShouldCheckNTP(t *testing.T) { ntp := NewProvider(&config) - assert.NoError(t, ntp.StartupCheck(logging.Logger())) + assert.NoError(t, ntp.StartupCheck()) } diff --git a/internal/ntp/types.go b/internal/ntp/types.go index 7aa69dad7..d47e732a6 100644 --- a/internal/ntp/types.go +++ b/internal/ntp/types.go @@ -1,12 +1,15 @@ package ntp import ( + "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/configuration/schema" ) // Provider type is the NTP provider. type Provider struct { config *schema.NTPConfiguration + log *logrus.Logger } type ntpVersion int diff --git a/internal/oidc/provider.go b/internal/oidc/provider.go index 69b076a8f..f12ce576c 100644 --- a/internal/oidc/provider.go +++ b/internal/oidc/provider.go @@ -20,10 +20,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) return provider, nil } - provider.Store, err = NewOpenIDConnectStore(configuration) - if err != nil { - return provider, err - } + provider.Store = NewOpenIDConnectStore(configuration) composeConfiguration := &compose.Config{ AccessTokenLifespan: configuration.AccessTokenLifespan, diff --git a/internal/oidc/store.go b/internal/oidc/store.go index 4ab335183..5a950a69e 100644 --- a/internal/oidc/store.go +++ b/internal/oidc/store.go @@ -14,7 +14,7 @@ import ( ) // NewOpenIDConnectStore returns a new OpenIDConnectStore using the provided schema.OpenIDConnectConfiguration. -func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore, err error) { +func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore) { logger := logging.Logger() store = &OpenIDConnectStore{ @@ -39,7 +39,7 @@ func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (st store.clients[client.ID] = NewClient(client) } - return store, nil + return store } // GetClientPolicy retrieves the policy from the client with the matching provided id. diff --git a/internal/oidc/store_test.go b/internal/oidc/store_test.go index d1bad7d4e..d69df0a83 100644 --- a/internal/oidc/store_test.go +++ b/internal/oidc/store_test.go @@ -12,7 +12,7 @@ import ( ) func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) { - s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ + s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ IssuerPrivateKey: exampleIssuerPrivateKey, Clients: []schema.OpenIDConnectClientConfiguration{ { @@ -32,8 +32,6 @@ func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) { }, }) - require.NoError(t, err) - policyOne := s.GetClientPolicy("myclient") assert.Equal(t, authorization.OneFactor, policyOne) @@ -45,7 +43,7 @@ func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) { } func TestOpenIDConnectStore_GetInternalClient(t *testing.T) { - s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ + s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ IssuerPrivateKey: exampleIssuerPrivateKey, Clients: []schema.OpenIDConnectClientConfiguration{ { @@ -58,8 +56,6 @@ func TestOpenIDConnectStore_GetInternalClient(t *testing.T) { }, }) - require.NoError(t, err) - client, err := s.GetClient(context.Background(), "myinvalidclient") assert.EqualError(t, err, "not_found") assert.Nil(t, client) @@ -78,13 +74,12 @@ func TestOpenIDConnectStore_GetInternalClient_ValidClient(t *testing.T) { Scopes: []string{"openid", "profile"}, Secret: "mysecret", } - s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ + + s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ IssuerPrivateKey: exampleIssuerPrivateKey, Clients: []schema.OpenIDConnectClientConfiguration{c1}, }) - require.NoError(t, err) - client, err := s.GetInternalClient(c1.ID) require.NoError(t, err) require.NotNil(t, client) @@ -106,20 +101,19 @@ func TestOpenIDConnectStore_GetInternalClient_InvalidClient(t *testing.T) { Scopes: []string{"openid", "profile"}, Secret: "mysecret", } - s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ + + s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ IssuerPrivateKey: exampleIssuerPrivateKey, Clients: []schema.OpenIDConnectClientConfiguration{c1}, }) - require.NoError(t, err) - client, err := s.GetInternalClient("another-client") assert.Nil(t, client) assert.EqualError(t, err, "not_found") } func TestOpenIDConnectStore_IsValidClientID(t *testing.T) { - s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ + s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ IssuerPrivateKey: exampleIssuerPrivateKey, Clients: []schema.OpenIDConnectClientConfiguration{ { @@ -132,8 +126,6 @@ func TestOpenIDConnectStore_IsValidClientID(t *testing.T) { }, }) - require.NoError(t, err) - validClient := s.IsValidClientID("myclient") invalidClient := s.IsValidClientID("myinvalidclient") diff --git a/internal/regulation/regulator.go b/internal/regulation/regulator.go index 66a40dbce..f4b89870a 100644 --- a/internal/regulation/regulator.go +++ b/internal/regulation/regulator.go @@ -1,6 +1,7 @@ package regulation import ( + "context" "fmt" "time" @@ -11,7 +12,7 @@ import ( ) // NewRegulator create a regulator instance. -func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.Provider, clock utils.Clock) *Regulator { +func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator { regulator := &Regulator{storageProvider: provider} regulator.clock = clock @@ -40,30 +41,25 @@ func NewRegulator(configuration *schema.RegulationConfiguration, provider storag return regulator } -// Mark mark an authentication attempt. +// Mark an authentication attempt. // We split Mark and Regulate in order to avoid timing attacks. -func (r *Regulator) Mark(username string, successful bool) error { - return r.storageProvider.AppendAuthenticationLog(models.AuthenticationAttempt{ +func (r *Regulator) Mark(ctx context.Context, username string, successful bool) error { + return r.storageProvider.AppendAuthenticationLog(ctx, models.AuthenticationAttempt{ Username: username, Successful: successful, Time: r.clock.Now(), }) } -// Regulate regulate the authentication attempts for a given user. -// This method returns ErrUserIsBanned if the user is banned along with the time until when -// the user is banned. -func (r *Regulator) Regulate(username string) (time.Time, error) { +// Regulate the authentication attempts for a given user. +// This method returns ErrUserIsBanned if the user is banned along with the time until when the user is banned. +func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, error) { // If there is regulation configuration, no regulation applies. if !r.enabled { return time.Time{}, nil } - now := r.clock.Now() - - // TODO(c.michaud): make sure FindTime < BanTime. - attempts, err := r.storageProvider.LoadLatestAuthenticationLogs(username, now.Add(-r.banTime)) - + attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0) if err != nil { return time.Time{}, nil } diff --git a/internal/regulation/regulator_test.go b/internal/regulation/regulator_test.go index 8bb4cc848..5201ee381 100644 --- a/internal/regulation/regulator_test.go +++ b/internal/regulation/regulator_test.go @@ -1,6 +1,7 @@ package regulation_test import ( + "context" "testing" "time" @@ -18,6 +19,7 @@ import ( type RegulatorSuite struct { suite.Suite + ctx context.Context ctrl *gomock.Controller storageMock *storage.MockProvider configuration schema.RegulationConfiguration @@ -27,6 +29,7 @@ type RegulatorSuite struct { func (s *RegulatorSuite) SetupTest() { s.ctrl = gomock.NewController(s.T()) s.storageMock = storage.NewMockProvider(s.ctrl) + s.ctx = context.Background() s.configuration = schema.RegulationConfiguration{ MaxRetries: 3, @@ -50,12 +53,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() { } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) } @@ -81,12 +84,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) } @@ -117,12 +120,12 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() { } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } @@ -150,12 +153,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() { } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } @@ -174,12 +177,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() { } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) } @@ -206,12 +209,12 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() { } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) } @@ -242,12 +245,12 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) } @@ -277,7 +280,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { } s.storageMock.EXPECT(). - LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()). + LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) // Check Disabled Functionality @@ -288,7 +291,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { } regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock) - _, err := regulator.Regulate("john") + _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) // Check Enabled Functionality @@ -299,6 +302,6 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { } regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock) - _, err = regulator.Regulate("john") + _, err = regulator.Regulate(s.ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } diff --git a/internal/regulation/types.go b/internal/regulation/types.go index 3877c6ea0..21d0612aa 100644 --- a/internal/regulation/types.go +++ b/internal/regulation/types.go @@ -18,7 +18,7 @@ type Regulator struct { // If a user has been banned, this duration is the timelapse during which the user is banned. banTime time.Duration - storageProvider storage.Provider + storageProvider storage.RegulatorProvider clock utils.Clock } diff --git a/internal/storage/const.go b/internal/storage/const.go index af9b4de63..c0cbd73d8 100644 --- a/internal/storage/const.go +++ b/internal/storage/const.go @@ -1,39 +1,52 @@ package storage import ( - "fmt" + "regexp" ) -const storageSchemaCurrentVersion = SchemaVersion(1) -const storageSchemaUpgradeMessage = "Storage schema upgraded to v" -const storageSchemaUpgradeErrorText = "storage schema upgrade failed at v" +const ( + tableUserPreferences = "user_preferences" + tableIdentityVerification = "identity_verification_tokens" + tableTOTPConfigurations = "totp_configurations" + tableU2FDevices = "u2f_devices" + tableDUODevices = "duo_devices" + tableAuthenticationLogs = "authentication_logs" + tableMigrations = "migrations" -// Keep table names in lower case because some DB does not support upper case. -const userPreferencesTableName = "user_preferences" -const identityVerificationTokensTableName = "identity_verification_tokens" -const totpSecretsTableName = "totp_secrets" -const u2fDeviceHandlesTableName = "u2f_devices" -const authenticationLogsTableName = "authentication_logs" -const configTableName = "config" + tablePrefixBackup = "_bkp_" +) -// sqlUpgradeCreateTableStatements is a map of the schema version number, plus a map of the table name and the statement used to create it. -// The statement is fmt.Sprintf'd with the table name as the first argument. -var sqlUpgradeCreateTableStatements = map[SchemaVersion]map[string]string{ - SchemaVersion(1): { - userPreferencesTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, second_factor_method VARCHAR(11))", - identityVerificationTokensTableName: "CREATE TABLE %s (token VARCHAR(512))", - totpSecretsTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, secret VARCHAR(64))", - u2fDeviceHandlesTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, keyHandle TEXT, publicKey TEXT)", - authenticationLogsTableName: "CREATE TABLE %s (username VARCHAR(100), successful BOOL, time INTEGER)", - configTableName: "CREATE TABLE %s (category VARCHAR(32) NOT NULL, key_name VARCHAR(32) NOT NULL, value TEXT, PRIMARY KEY (category, key_name))", - }, -} +// WARNING: Do not change/remove these consts. They are used for Pre1 migrations. +const ( + tablePre1TOTPSecrets = "totp_secrets" + tablePre1Config = "config" + tablePre1IdentityVerificationTokens = "identity_verification_tokens" + tableAlphaAuthenticationLogs = "AuthenticationLogs" + tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens" + tableAlphaPreferences = "Preferences" + tableAlphaPreferencesTableName = "PreferencesTableName" + tableAlphaSecondFactorPreferences = "SecondFactorPreferences" + tableAlphaTOTPSecrets = "TOTPSecrets" + tableAlphaU2FDeviceHandles = "U2FDeviceHandles" +) -// sqlUpgradesCreateTableIndexesStatements is a map of t he schema version number, plus a slice of statements to create all of the indexes. -var sqlUpgradesCreateTableIndexesStatements = map[SchemaVersion][]string{ - SchemaVersion(1): { - fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s (username, time)", authenticationLogsTableName), - }, -} +const ( + providerAll = "all" + providerMySQL = "mysql" + providerPostgres = "postgres" + providerSQLite = "sqlite" +) -const unitTestUser = "john" +const ( + // This is the latest schema version for the purpose of tests. + testLatestVersion = 1 +) + +const ( + // SchemaLatest represents the value expected for a "migrate to latest" migration. It's the maximum 32bit signed integer. + SchemaLatest = 2147483647 +) + +var ( + reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`) +) diff --git a/internal/storage/errors.go b/internal/storage/errors.go index 7cf6e8c0c..d84a2bd66 100644 --- a/internal/storage/errors.go +++ b/internal/storage/errors.go @@ -1,6 +1,8 @@ package storage -import "errors" +import ( + "errors" +) var ( // ErrNoU2FDeviceHandle error thrown when no U2F device handle has been found in DB. @@ -8,4 +10,35 @@ var ( // ErrNoTOTPSecret error thrown when no TOTP secret has been found in DB. ErrNoTOTPSecret = errors.New("no TOTP secret registered") + + // ErrNoAvailableMigrations is returned when no available migrations can be found. + ErrNoAvailableMigrations = errors.New("no available migrations") + + // ErrSchemaAlreadyUpToDate is returned when the schema is already up to date. + ErrSchemaAlreadyUpToDate = errors.New("schema already up to date") + + // ErrNoMigrationsFound is returned when no migrations were found. + ErrNoMigrationsFound = errors.New("no schema migrations found") +) + +// Error formats for the storage provider. +const ( + ErrFmtMigrateUpTargetLessThanCurrent = "schema up migration target version %d is less then the current version %d" + ErrFmtMigrateUpTargetGreaterThanLatest = "schema up migration target version %d is greater then the latest version %d which indicates it doesn't exist" + ErrFmtMigrateDownTargetGreaterThanCurrent = "schema down migration target version %d is greater than the current version %d" + ErrFmtMigrateDownTargetLessThanMinimum = "schema down migration target version %d is less than the minimum version" + ErrFmtMigrateAlreadyOnTargetVersion = "schema migration target version %d is the same current version %d" +) + +const ( + errFmtFailedMigration = "schema migration %d (%s) failed: %w" + errFmtFailedMigrationPre1 = "schema migration pre1 failed: %w" + errFmtSchemaCurrentGreaterThanLatestKnown = "current schema version is greater than the latest known schema " + + "version, you must downgrade to schema version %d before you can use this version of Authelia" +) + +const ( + logFmtMigrationFromTo = "Storage schema migration from %s to %s is being attempted" + logFmtMigrationComplete = "Storage schema migration from %s to %s is complete" + logFmtErrClosingConn = "Error occurred closing SQL connection: %v" ) diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go new file mode 100644 index 000000000..4a8427abe --- /dev/null +++ b/internal/storage/migrations.go @@ -0,0 +1,204 @@ +package storage + +import ( + "embed" + "errors" + "fmt" + "sort" + "strconv" + "strings" +) + +//go:embed migrations/* +var migrationsFS embed.FS + +func latestMigrationVersion(providerName string) (version int, err error) { + entries, err := migrationsFS.ReadDir("migrations") + if err != nil { + return -1, err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + m, err := scanMigration(entry.Name()) + if err != nil { + return -1, err + } + + if m.Provider != providerName { + continue + } + + if !m.Up { + continue + } + + if m.Version > version { + version = m.Version + } + } + + return version, nil +} + +func loadMigration(providerName string, version int, up bool) (migration *SchemaMigration, err error) { + entries, err := migrationsFS.ReadDir("migrations") + if err != nil { + return nil, err + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + m, err := scanMigration(entry.Name()) + if err != nil { + return nil, err + } + + migration = &m + + if up != migration.Up { + continue + } + + if migration.Provider != providerAll && migration.Provider != providerName { + continue + } + + if version != migration.Version { + continue + } + + return migration, nil + } + + return nil, errors.New("migration not found") +} + +// loadMigrations scans the migrations fs and loads the appropriate migrations for a given providerName, prior and +// target versions. If the target version is -1 this indicates the latest version. If the target version is 0 +// this indicates the database zero state. +func loadMigrations(providerName string, prior, target int) (migrations []SchemaMigration, err error) { + if prior == target && (prior != -1 || target != -1) { + return nil, errors.New("cannot migrate to the same version as prior") + } + + entries, err := migrationsFS.ReadDir("migrations") + if err != nil { + return nil, err + } + + up := prior < target + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + migration, err := scanMigration(entry.Name()) + if err != nil { + return nil, err + } + + if skipMigration(providerName, up, target, prior, &migration) { + continue + } + + migrations = append(migrations, migration) + } + + if up { + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].Version < migrations[j].Version + }) + } else { + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].Version > migrations[j].Version + }) + } + + return migrations, nil +} + +func skipMigration(providerName string, up bool, target, prior int, migration *SchemaMigration) (skip bool) { + if migration.Provider != providerAll && migration.Provider != providerName { + // Skip if migration.Provider is not a match. + return true + } + + if up { + if !migration.Up { + // Skip if we wanted an Up migration but it isn't an Up migration. + return true + } + + if target != -1 && (migration.Version > target || migration.Version <= prior) { + // Skip if the migration version is greater than the target or less than or equal to the previous version. + return true + } + } else { + if migration.Up { + // Skip if we didn't want an Up migration but it is an Up migration. + return true + } + + if migration.Version == 1 && target == -1 { + // Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data + // preventing a successful migration. + return true + } + + if migration.Version <= target || migration.Version > prior { + // Skip the migration if we want to go down and the migration version is less than or equal to the target + // or greater than the previous version. + return true + } + } + + return false +} + +func scanMigration(m string) (migration SchemaMigration, err error) { + result := reMigration.FindStringSubmatch(m) + + if result == nil || len(result) != 5 { + return SchemaMigration{}, errors.New("invalid migration: could not parse the format") + } + + migration = SchemaMigration{ + Name: strings.ReplaceAll(result[2], "_", " "), + Provider: result[3], + } + + data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)) + if err != nil { + return SchemaMigration{}, err + } + + migration.Query = string(data) + + switch result[4] { + case "up": + migration.Up = true + case "down": + migration.Up = false + default: + return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4]) + } + + migration.Version, _ = strconv.Atoi(result[1]) + + switch migration.Provider { + case providerAll, providerSQLite, providerMySQL, providerPostgres: + break + default: + return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3]) + } + + return migration, nil +} diff --git a/internal/storage/migrations/V0001.Initial_Schema.all.down.sql b/internal/storage/migrations/V0001.Initial_Schema.all.down.sql new file mode 100644 index 000000000..1fef4a536 --- /dev/null +++ b/internal/storage/migrations/V0001.Initial_Schema.all.down.sql @@ -0,0 +1,6 @@ +DROP TABLE IF EXISTS authentication_logs; +DROP TABLE IF EXISTS identity_verification_tokens; +DROP TABLE IF EXISTS totp_configurations; +DROP TABLE IF EXISTS u2f_devices; +DROP TABLE IF EXISTS user_preferences; +DROP TABLE IF EXISTS migrations; diff --git a/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql b/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql new file mode 100644 index 000000000..858959b5b --- /dev/null +++ b/internal/storage/migrations/V0001.Initial_Schema.mysql.up.sql @@ -0,0 +1,55 @@ +CREATE TABLE IF NOT EXISTS authentication_logs ( + id INTEGER AUTO_INCREMENT, + time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + successful BOOL NOT NULL, + username VARCHAR(100) NOT NULL, + PRIMARY KEY (id) +); + +CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username); + +CREATE TABLE IF NOT EXISTS identity_verification_tokens ( + id INTEGER AUTO_INCREMENT, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + token VARCHAR(512), + PRIMARY KEY (id), + UNIQUE KEY (token) +); + +CREATE TABLE IF NOT EXISTS totp_configurations ( + id INTEGER AUTO_INCREMENT, + username VARCHAR(100) NOT NULL, + algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1', + digits INTEGER NOT NULL DEFAULT 6, + totp_period INTEGER NOT NULL DEFAULT 30, + secret VARCHAR(64) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY (username) +); + +CREATE TABLE IF NOT EXISTS u2f_devices ( + id INTEGER AUTO_INCREMENT, + username VARCHAR(100) NOT NULL, + description VARCHAR(30) NOT NULL DEFAULT 'Primary', + key_handle BLOB NOT NULL, + public_key BLOB NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY (username, description) +); + +CREATE TABLE IF NOT EXISTS user_preferences ( + id INTEGER AUTO_INCREMENT, + username VARCHAR(100) NOT NULL, + second_factor_method VARCHAR(11) NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY (username) +); + +CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER AUTO_INCREMENT, + applied TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + version_before INTEGER NULL DEFAULT NULL, + version_after INTEGER NOT NULL, + application_version VARCHAR(128) NOT NULL, + PRIMARY KEY (id) +); diff --git a/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql b/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql new file mode 100644 index 000000000..ade06ccfb --- /dev/null +++ b/internal/storage/migrations/V0001.Initial_Schema.postgres.up.sql @@ -0,0 +1,55 @@ +CREATE TABLE IF NOT EXISTS authentication_logs ( + id SERIAL, + time TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + successful BOOLEAN NOT NULL, + username VARCHAR(100) NOT NULL, + PRIMARY KEY (id) +); + +CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username); + +CREATE TABLE IF NOT EXISTS identity_verification_tokens ( + id SERIAL, + created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + token VARCHAR(512), + PRIMARY KEY (id), + UNIQUE (token) +); + +CREATE TABLE IF NOT EXISTS totp_configurations ( + id SERIAL, + username VARCHAR(100) NOT NULL, + algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1', + digits INTEGER NOT NULL DEFAULT 6, + totp_period INTEGER NOT NULL DEFAULT 30, + secret VARCHAR(64) NOT NULL, + PRIMARY KEY (id), + UNIQUE (username) +); + +CREATE TABLE IF NOT EXISTS u2f_devices ( + id SERIAL, + username VARCHAR(100) NOT NULL, + description VARCHAR(30) NOT NULL DEFAULT 'Primary', + key_handle BYTEA NOT NULL, + public_key BYTEA NOT NULL, + PRIMARY KEY (id), + UNIQUE (username, description) +); + +CREATE TABLE IF NOT EXISTS user_preferences ( + id SERIAL, + username VARCHAR(100) NOT NULL, + second_factor_method VARCHAR(11) NOT NULL, + PRIMARY KEY (id), + UNIQUE (username) +); + +CREATE TABLE IF NOT EXISTS migrations ( + id SERIAL, + applied TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + version_before INTEGER NULL DEFAULT NULL, + version_after INTEGER NOT NULL, + application_version VARCHAR(128) NOT NULL, + PRIMARY KEY (id) +); diff --git a/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql b/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql new file mode 100644 index 000000000..139d45853 --- /dev/null +++ b/internal/storage/migrations/V0001.Initial_Schema.sqlite.up.sql @@ -0,0 +1,54 @@ +CREATE TABLE IF NOT EXISTS authentication_logs ( + id INTEGER, + time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + successful BOOLEAN NOT NULL, + username VARCHAR(100) NOT NULL, + PRIMARY KEY (id) +); + +CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username); + +CREATE TABLE IF NOT EXISTS identity_verification_tokens ( + id INTEGER, + created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + token VARCHAR(512), + PRIMARY KEY (id), + UNIQUE (token) +); + +CREATE TABLE IF NOT EXISTS totp_configurations ( + id INTEGER, + username VARCHAR(100) NOT NULL, + algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1', + digits INTEGER(1) NOT NULL DEFAULT 6, + totp_period INTEGER NOT NULL DEFAULT 30, + secret VARCHAR(64) NOT NULL, + PRIMARY KEY (id), + UNIQUE (username) +); + +CREATE TABLE IF NOT EXISTS u2f_devices ( + id INTEGER, + username VARCHAR(100) NOT NULL, + description VARCHAR(30) NOT NULL DEFAULT 'Primary', + key_handle BLOB NOT NULL, + public_key BLOB NOT NULL, + PRIMARY KEY (id), + UNIQUE (username, description) +); + +CREATE TABLE IF NOT EXISTS user_preferences ( + id INTEGER, + username VARCHAR(100) UNIQUE NOT NULL, + second_factor_method VARCHAR(11) NOT NULL, + PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER, + applied TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + version_before INTEGER NULL DEFAULT NULL, + version_after INTEGER NOT NULL, + application_version VARCHAR(128) NOT NULL, + PRIMARY KEY (id) +); diff --git a/internal/storage/migrations_test.go b/internal/storage/migrations_test.go new file mode 100644 index 000000000..38fa84f42 --- /dev/null +++ b/internal/storage/migrations_test.go @@ -0,0 +1,154 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShouldObtainCorrectUpMigrations(t *testing.T) { + ver, err := latestMigrationVersion(providerSQLite) + require.NoError(t, err) + + assert.Equal(t, testLatestVersion, ver) + + migrations, err := loadMigrations(providerSQLite, 0, ver) + require.NoError(t, err) + + assert.Len(t, migrations, ver) + + for i := 0; i < len(migrations); i++ { + assert.Equal(t, i+1, migrations[i].Version) + } +} + +func TestShouldObtainCorrectDownMigrations(t *testing.T) { + ver, err := latestMigrationVersion(providerSQLite) + require.NoError(t, err) + + assert.Equal(t, testLatestVersion, ver) + + migrations, err := loadMigrations(providerSQLite, ver, 0) + require.NoError(t, err) + + assert.Len(t, migrations, ver) + + for i := 0; i < len(migrations); i++ { + assert.Equal(t, ver-i, migrations[i].Version) + } +} + +func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) { + migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest) + require.NoError(t, err) + require.NotEqual(t, 0, len(migrations)) + + previousUp := make([]int, len(migrations)) + + for i, migration := range migrations { + assert.True(t, migration.Up) + + if i != 0 { + for _, v := range previousUp { + assert.NotEqual(t, v, migration.Version) + } + } + + previousUp = append(previousUp, migration.Version) + } + + migrations, err = loadMigrations(providerPostgres, SchemaLatest, 0) + require.NoError(t, err) + require.NotEqual(t, 0, len(migrations)) + + previousDown := make([]int, len(migrations)) + + for i, migration := range migrations { + assert.False(t, migration.Up) + + if i != 0 { + for _, v := range previousDown { + assert.NotEqual(t, v, migration.Version) + } + } + + previousDown = append(previousDown, migration.Version) + } +} + +func TestMigrationsShouldNotBeDuplicatedMySQL(t *testing.T) { + migrations, err := loadMigrations(providerMySQL, 0, SchemaLatest) + require.NoError(t, err) + require.NotEqual(t, 0, len(migrations)) + + previousUp := make([]int, len(migrations)) + + for i, migration := range migrations { + assert.True(t, migration.Up) + + if i != 0 { + for _, v := range previousUp { + assert.NotEqual(t, v, migration.Version) + } + } + + previousUp = append(previousUp, migration.Version) + } + + migrations, err = loadMigrations(providerMySQL, SchemaLatest, 0) + require.NoError(t, err) + require.NotEqual(t, 0, len(migrations)) + + previousDown := make([]int, len(migrations)) + + for i, migration := range migrations { + assert.False(t, migration.Up) + + if i != 0 { + for _, v := range previousDown { + assert.NotEqual(t, v, migration.Version) + } + } + + previousDown = append(previousDown, migration.Version) + } +} + +func TestMigrationsShouldNotBeDuplicatedSQLite(t *testing.T) { + migrations, err := loadMigrations(providerSQLite, 0, SchemaLatest) + require.NoError(t, err) + require.NotEqual(t, 0, len(migrations)) + + previousUp := make([]int, len(migrations)) + + for i, migration := range migrations { + assert.True(t, migration.Up) + + if i != 0 { + for _, v := range previousUp { + assert.NotEqual(t, v, migration.Version) + } + } + + previousUp = append(previousUp, migration.Version) + } + + migrations, err = loadMigrations(providerSQLite, SchemaLatest, 0) + require.NoError(t, err) + require.NotEqual(t, 0, len(migrations)) + + previousDown := make([]int, len(migrations)) + + for i, migration := range migrations { + assert.False(t, migration.Up) + + if i != 0 { + for _, v := range previousDown { + assert.NotEqual(t, v, migration.Version) + } + } + + previousDown = append(previousDown, migration.Version) + } +} diff --git a/internal/storage/mysql_provider.go b/internal/storage/mysql_provider.go deleted file mode 100644 index 29e9a778d..000000000 --- a/internal/storage/mysql_provider.go +++ /dev/null @@ -1,85 +0,0 @@ -package storage - -import ( - "database/sql" - "fmt" - "time" - - _ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string. - - "github.com/authelia/authelia/v4/internal/configuration/schema" -) - -// MySQLProvider is a MySQL provider. -type MySQLProvider struct { - SQLProvider -} - -// NewMySQLProvider a MySQL provider. -func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProvider { - provider := MySQLProvider{ - SQLProvider{ - name: "mysql", - - sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements, - - sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName), - sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName), - - sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName), - sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName), - sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName), - - sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName), - sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName), - sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName), - - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName), - - sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName), - sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName), - - sqlGetExistingTables: "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema=database()", - - sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName), - sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName), - }, - } - - provider.sqlUpgradesCreateTableStatements[SchemaVersion(1)][authenticationLogsTableName] = "CREATE TABLE %s (username VARCHAR(100), successful BOOL, time INTEGER, INDEX usr_time_idx (username, time))" - - connectionString := configuration.Username - - if configuration.Password != "" { - connectionString += fmt.Sprintf(":%s", configuration.Password) - } - - if connectionString != "" { - connectionString += "@" - } - - address := configuration.Host - if configuration.Port > 0 { - address += fmt.Sprintf(":%d", configuration.Port) - } - - connectionString += fmt.Sprintf("tcp(%s)", address) - if configuration.Database != "" { - connectionString += fmt.Sprintf("/%s", configuration.Database) - } - - connectionString += "?" - connectionString += fmt.Sprintf("timeout=%ds", int32(configuration.Timeout/time.Second)) - - db, err := sql.Open("mysql", connectionString) - if err != nil { - provider.log.Fatalf("Unable to connect to SQL database: %v", err) - } - - if err := provider.initialize(db); err != nil { - provider.log.Fatalf("Unable to initialize SQL database: %v", err) - } - - return &provider -} diff --git a/internal/storage/postgres_provider.go b/internal/storage/postgres_provider.go deleted file mode 100644 index c2783d082..000000000 --- a/internal/storage/postgres_provider.go +++ /dev/null @@ -1,90 +0,0 @@ -package storage - -import ( - "database/sql" - "fmt" - "strings" - "time" - - _ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string. - - "github.com/authelia/authelia/v4/internal/configuration/schema" -) - -// PostgreSQLProvider is a PostgreSQL provider. -type PostgreSQLProvider struct { - SQLProvider -} - -// NewPostgreSQLProvider a PostgreSQL provider. -func NewPostgreSQLProvider(configuration schema.PostgreSQLStorageConfiguration) *PostgreSQLProvider { - provider := PostgreSQLProvider{ - SQLProvider{ - name: "postgres", - - sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements, - sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements, - - sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=$1", userPreferencesTableName), - sqlUpsertSecondFactorPreference: fmt.Sprintf("INSERT INTO %s (username, second_factor_method) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET second_factor_method=$2", userPreferencesTableName), - - sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=$1)", identityVerificationTokensTableName), - sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES ($1)", identityVerificationTokensTableName), - sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=$1", identityVerificationTokensTableName), - - sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=$1", totpSecretsTableName), - sqlUpsertTOTPSecret: fmt.Sprintf("INSERT INTO %s (username, secret) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET secret=$2", totpSecretsTableName), - sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=$1", totpSecretsTableName), - - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=$1", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, keyHandle, publicKey) VALUES ($1, $2, $3) ON CONFLICT (username) DO UPDATE SET keyHandle=$2, publicKey=$3", u2fDeviceHandlesTableName), - - sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES ($1, $2, $3)", authenticationLogsTableName), - sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>$1 AND username=$2 ORDER BY time DESC", authenticationLogsTableName), - - sqlGetExistingTables: "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema='public'", - - sqlConfigSetValue: fmt.Sprintf("INSERT INTO %s (category, key_name, value) VALUES ($1, $2, $3) ON CONFLICT (category, key_name) DO UPDATE SET value=$3", configTableName), - sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=$1 AND key_name=$2", configTableName), - }, - } - - args := make([]string, 0) - if configuration.Username != "" { - args = append(args, fmt.Sprintf("user='%s'", configuration.Username)) - } - - if configuration.Password != "" { - args = append(args, fmt.Sprintf("password='%s'", configuration.Password)) - } - - if configuration.Host != "" { - args = append(args, fmt.Sprintf("host=%s", configuration.Host)) - } - - if configuration.Port > 0 { - args = append(args, fmt.Sprintf("port=%d", configuration.Port)) - } - - if configuration.Database != "" { - args = append(args, fmt.Sprintf("dbname=%s", configuration.Database)) - } - - if configuration.SSLMode != "" { - args = append(args, fmt.Sprintf("sslmode=%s", configuration.SSLMode)) - } - - args = append(args, fmt.Sprintf("connect_timeout=%d", int32(configuration.Timeout/time.Second))) - connectionString := strings.Join(args, " ") - - db, err := sql.Open("pgx", connectionString) - if err != nil { - provider.log.Fatalf("Unable to connect to SQL database: %v", err) - } - - if err := provider.initialize(db); err != nil { - provider.log.Fatalf("Unable to initialize SQL database: %v", err) - } - - return &provider -} diff --git a/internal/storage/provider.go b/internal/storage/provider.go index d1374ec1e..d95881253 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -1,28 +1,45 @@ package storage import ( + "context" "time" "github.com/authelia/authelia/v4/internal/models" ) -// Provider is an interface providing storage capabilities for -// persisting any kind of data related to Authelia. +// Provider is an interface providing storage capabilities for persisting any kind of data related to Authelia. type Provider interface { - LoadPreferred2FAMethod(username string) (string, error) - SavePreferred2FAMethod(username string, method string) error + models.StartupCheck - FindIdentityVerificationToken(token string) (bool, error) - SaveIdentityVerificationToken(token string) error - RemoveIdentityVerificationToken(token string) error + RegulatorProvider - SaveTOTPSecret(username string, secret string) error - LoadTOTPSecret(username string) (string, error) - DeleteTOTPSecret(username string) error + SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) + LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) + LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) - SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error - LoadU2FDeviceHandle(username string) (keyHandle []byte, publicKey []byte, err error) + SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) + RemoveIdentityVerification(ctx context.Context, jti string) (err error) + FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) - AppendAuthenticationLog(attempt models.AuthenticationAttempt) error - LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) + SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) + DeleteTOTPConfiguration(ctx context.Context, username string) (err error) + LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) + + SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) + LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) + + SchemaTables(ctx context.Context) (tables []string, err error) + SchemaVersion(ctx context.Context) (version int, err error) + SchemaMigrate(ctx context.Context, up bool, version int) (err error) + SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error) + + SchemaLatestVersion() (version int, err error) + SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error) + SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error) +} + +// RegulatorProvider is an interface providing storage capabilities for persisting any kind of data related to the regulator. +type RegulatorProvider interface { + AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) + LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) } diff --git a/internal/storage/provider_mock.go b/internal/storage/provider_mock.go index 8414e6dbe..d9a773ab9 100644 --- a/internal/storage/provider_mock.go +++ b/internal/storage/provider_mock.go @@ -1,10 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: internal/storage/provider.go +// Source: ./internal/storage/provider.go -// Package storage is a generated GoMock package. package storage import ( + context "context" reflect "reflect" time "time" @@ -13,199 +13,331 @@ import ( models "github.com/authelia/authelia/v4/internal/models" ) -// MockProvider is a mock of Provider interface +// MockProvider is a mock of Provider interface. type MockProvider struct { ctrl *gomock.Controller recorder *MockProviderMockRecorder } -// MockProviderMockRecorder is the mock recorder for MockProvider +// MockProviderMockRecorder is the mock recorder for MockProvider. type MockProviderMockRecorder struct { mock *MockProvider } -// NewMockProvider creates a new mock instance +// NewMockProvider creates a new mock instance. func NewMockProvider(ctrl *gomock.Controller) *MockProvider { mock := &MockProvider{ctrl: ctrl} mock.recorder = &MockProviderMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockProvider) EXPECT() *MockProviderMockRecorder { return m.recorder } -// LoadPreferred2FAMethod mocks base method -func (m *MockProvider) LoadPreferred2FAMethod(username string) (string, error) { +// AppendAuthenticationLog mocks base method. +func (m *MockProvider) AppendAuthenticationLog(arg0 context.Context, arg1 models.AuthenticationAttempt) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", username) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod -func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), username) -} - -// SavePreferred2FAMethod mocks base method -func (m *MockProvider) SavePreferred2FAMethod(username, method string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SavePreferred2FAMethod", username, method) + ret := m.ctrl.Call(m, "AppendAuthenticationLog", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } -// SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod -func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(username, method interface{}) *gomock.Call { +// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog. +func (mr *MockProviderMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), username, method) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), arg0, arg1) } -// FindIdentityVerificationToken mocks base method -func (m *MockProvider) FindIdentityVerificationToken(token string) (bool, error) { +// DeleteTOTPConfiguration mocks base method. +func (m *MockProvider) DeleteTOTPConfiguration(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FindIdentityVerificationToken", token) + ret := m.ctrl.Call(m, "DeleteTOTPConfiguration", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteTOTPConfiguration indicates an expected call of DeleteTOTPConfiguration. +func (mr *MockProviderMockRecorder) DeleteTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPConfiguration), arg0, arg1) +} + +// FindIdentityVerification mocks base method. +func (m *MockProvider) FindIdentityVerification(arg0 context.Context, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindIdentityVerification", arg0, arg1) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken -func (mr *MockProviderMockRecorder) FindIdentityVerificationToken(token interface{}) *gomock.Call { +// FindIdentityVerification indicates an expected call of FindIdentityVerification. +func (mr *MockProviderMockRecorder) FindIdentityVerification(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerificationToken), token) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerification", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerification), arg0, arg1) } -// SaveIdentityVerificationToken mocks base method -func (m *MockProvider) SaveIdentityVerificationToken(token string) error { +// LoadAuthenticationLogs mocks base method. +func (m *MockProvider) LoadAuthenticationLogs(arg0 context.Context, arg1 string, arg2 time.Time, arg3, arg4 int) ([]models.AuthenticationAttempt, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", token) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken -func (mr *MockProviderMockRecorder) SaveIdentityVerificationToken(token interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerificationToken), token) -} - -// RemoveIdentityVerificationToken mocks base method -func (m *MockProvider) RemoveIdentityVerificationToken(token string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", token) - ret0, _ := ret[0].(error) - return ret0 -} - -// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken -func (mr *MockProviderMockRecorder) RemoveIdentityVerificationToken(token interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerificationToken), token) -} - -// SaveTOTPSecret mocks base method -func (m *MockProvider) SaveTOTPSecret(username, secret string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveTOTPSecret", username, secret) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveTOTPSecret indicates an expected call of SaveTOTPSecret -func (mr *MockProviderMockRecorder) SaveTOTPSecret(username, secret interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockProvider)(nil).SaveTOTPSecret), username, secret) -} - -// LoadTOTPSecret mocks base method -func (m *MockProvider) LoadTOTPSecret(username string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadTOTPSecret", username) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadTOTPSecret indicates an expected call of LoadTOTPSecret -func (mr *MockProviderMockRecorder) LoadTOTPSecret(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockProvider)(nil).LoadTOTPSecret), username) -} - -// DeleteTOTPSecret mocks base method -func (m *MockProvider) DeleteTOTPSecret(username string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteTOTPSecret", username) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteTOTPSecret indicates an expected call of DeleteTOTPSecret -func (mr *MockProviderMockRecorder) DeleteTOTPSecret(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPSecret", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPSecret), username) -} - -// SaveU2FDeviceHandle mocks base method -func (m *MockProvider) SaveU2FDeviceHandle(username string, keyHandle, publicKey []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", username, keyHandle, publicKey) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle -func (mr *MockProviderMockRecorder) SaveU2FDeviceHandle(username, keyHandle, publicKey interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).SaveU2FDeviceHandle), username, keyHandle, publicKey) -} - -// LoadU2FDeviceHandle mocks base method -func (m *MockProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", username) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].([]byte) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle -func (mr *MockProviderMockRecorder) LoadU2FDeviceHandle(username interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).LoadU2FDeviceHandle), username) -} - -// AppendAuthenticationLog mocks base method -func (m *MockProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppendAuthenticationLog", attempt) - ret0, _ := ret[0].(error) - return ret0 -} - -// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog -func (mr *MockProviderMockRecorder) AppendAuthenticationLog(attempt interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), attempt) -} - -// LoadLatestAuthenticationLogs mocks base method -func (m *MockProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", username, fromDate) + ret := m.ctrl.Call(m, "LoadAuthenticationLogs", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([]models.AuthenticationAttempt) ret1, _ := ret[1].(error) return ret0, ret1 } -// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs -func (mr *MockProviderMockRecorder) LoadLatestAuthenticationLogs(username, fromDate interface{}) *gomock.Call { +// LoadAuthenticationLogs indicates an expected call of LoadAuthenticationLogs. +func (mr *MockProviderMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadLatestAuthenticationLogs), username, fromDate) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4) +} + +// LoadPreferred2FAMethod mocks base method. +func (m *MockProvider) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod. +func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), arg0, arg1) +} + +// LoadTOTPConfiguration mocks base method. +func (m *MockProvider) LoadTOTPConfiguration(arg0 context.Context, arg1 string) (*models.TOTPConfiguration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadTOTPConfiguration", arg0, arg1) + ret0, _ := ret[0].(*models.TOTPConfiguration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadTOTPConfiguration indicates an expected call of LoadTOTPConfiguration. +func (mr *MockProviderMockRecorder) LoadTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).LoadTOTPConfiguration), arg0, arg1) +} + +// LoadU2FDevice mocks base method. +func (m *MockProvider) LoadU2FDevice(arg0 context.Context, arg1 string) (*models.U2FDevice, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadU2FDevice", arg0, arg1) + ret0, _ := ret[0].(*models.U2FDevice) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadU2FDevice indicates an expected call of LoadU2FDevice. +func (mr *MockProviderMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockProvider)(nil).LoadU2FDevice), arg0, arg1) +} + +// LoadUserInfo mocks base method. +func (m *MockProvider) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadUserInfo", arg0, arg1) + ret0, _ := ret[0].(models.UserInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadUserInfo indicates an expected call of LoadUserInfo. +func (mr *MockProviderMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockProvider)(nil).LoadUserInfo), arg0, arg1) +} + +// RemoveIdentityVerification mocks base method. +func (m *MockProvider) RemoveIdentityVerification(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification. +func (mr *MockProviderMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerification), arg0, arg1) +} + +// SaveIdentityVerification mocks base method. +func (m *MockProvider) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveIdentityVerification", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveIdentityVerification indicates an expected call of SaveIdentityVerification. +func (mr *MockProviderMockRecorder) SaveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerification), arg0, arg1) +} + +// SavePreferred2FAMethod mocks base method. +func (m *MockProvider) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SavePreferred2FAMethod", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod. +func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), arg0, arg1, arg2) +} + +// SaveTOTPConfiguration mocks base method. +func (m *MockProvider) SaveTOTPConfiguration(arg0 context.Context, arg1 models.TOTPConfiguration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveTOTPConfiguration", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveTOTPConfiguration indicates an expected call of SaveTOTPConfiguration. +func (mr *MockProviderMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).SaveTOTPConfiguration), arg0, arg1) +} + +// SaveU2FDevice mocks base method. +func (m *MockProvider) SaveU2FDevice(arg0 context.Context, arg1 models.U2FDevice) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveU2FDevice", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveU2FDevice indicates an expected call of SaveU2FDevice. +func (mr *MockProviderMockRecorder) SaveU2FDevice(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDevice", reflect.TypeOf((*MockProvider)(nil).SaveU2FDevice), arg0, arg1) +} + +// SchemaLatestVersion mocks base method. +func (m *MockProvider) SchemaLatestVersion() (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaLatestVersion") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SchemaLatestVersion indicates an expected call of SchemaLatestVersion. +func (mr *MockProviderMockRecorder) SchemaLatestVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaLatestVersion", reflect.TypeOf((*MockProvider)(nil).SchemaLatestVersion)) +} + +// SchemaMigrate mocks base method. +func (m *MockProvider) SchemaMigrate(arg0 context.Context, arg1 bool, arg2 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaMigrate", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SchemaMigrate indicates an expected call of SchemaMigrate. +func (mr *MockProviderMockRecorder) SchemaMigrate(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrate", reflect.TypeOf((*MockProvider)(nil).SchemaMigrate), arg0, arg1, arg2) +} + +// SchemaMigrationHistory mocks base method. +func (m *MockProvider) SchemaMigrationHistory(arg0 context.Context) ([]models.Migration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaMigrationHistory", arg0) + ret0, _ := ret[0].([]models.Migration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SchemaMigrationHistory indicates an expected call of SchemaMigrationHistory. +func (mr *MockProviderMockRecorder) SchemaMigrationHistory(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationHistory", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationHistory), arg0) +} + +// SchemaMigrationsDown mocks base method. +func (m *MockProvider) SchemaMigrationsDown(arg0 context.Context, arg1 int) ([]SchemaMigration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaMigrationsDown", arg0, arg1) + ret0, _ := ret[0].([]SchemaMigration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SchemaMigrationsDown indicates an expected call of SchemaMigrationsDown. +func (mr *MockProviderMockRecorder) SchemaMigrationsDown(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsDown", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsDown), arg0, arg1) +} + +// SchemaMigrationsUp mocks base method. +func (m *MockProvider) SchemaMigrationsUp(arg0 context.Context, arg1 int) ([]SchemaMigration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaMigrationsUp", arg0, arg1) + ret0, _ := ret[0].([]SchemaMigration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SchemaMigrationsUp indicates an expected call of SchemaMigrationsUp. +func (mr *MockProviderMockRecorder) SchemaMigrationsUp(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsUp", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsUp), arg0, arg1) +} + +// SchemaTables mocks base method. +func (m *MockProvider) SchemaTables(arg0 context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaTables", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SchemaTables indicates an expected call of SchemaTables. +func (mr *MockProviderMockRecorder) SchemaTables(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaTables", reflect.TypeOf((*MockProvider)(nil).SchemaTables), arg0) +} + +// SchemaVersion mocks base method. +func (m *MockProvider) SchemaVersion(arg0 context.Context) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SchemaVersion", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SchemaVersion indicates an expected call of SchemaVersion. +func (mr *MockProviderMockRecorder) SchemaVersion(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaVersion", reflect.TypeOf((*MockProvider)(nil).SchemaVersion), arg0) +} + +// StartupCheck mocks base method. +func (m *MockProvider) StartupCheck() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartupCheck") + ret0, _ := ret[0].(error) + return ret0 +} + +// StartupCheck indicates an expected call of StartupCheck. +func (mr *MockProviderMockRecorder) StartupCheck() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockProvider)(nil).StartupCheck)) } diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 8bd27a7e8..f136e4c6e 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -1,173 +1,199 @@ package storage import ( + "context" "database/sql" - "encoding/base64" + "errors" "fmt" "time" + "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" + "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/models" - "github.com/authelia/authelia/v4/internal/utils" ) +// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's. +func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) { + db, err := sqlx.Open(driverName, dataSourceName) + + provider = SQLProvider{ + name: name, + driverName: driverName, + db: db, + log: logging.Logger(), + errOpen: err, + + sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), + sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), + + sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification), + sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification), + sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification), + + sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), + sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), + sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), + + sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices), + sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices), + + sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences), + sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), + sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableUserPreferences), + + sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), + sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), + sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations), + + sqlFmtRenameTable: queryFmtRenameTable, + } + + return provider +} + // SQLProvider is a storage provider persisting data in a SQL database. type SQLProvider struct { - db *sql.DB - log *logrus.Logger - name string + db *sqlx.DB + log *logrus.Logger + name string + driverName string + errOpen error - sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string - sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string + // Table: authentication_logs. + sqlInsertAuthenticationAttempt string + sqlSelectAuthenticationAttemptsByUsername string - sqlGetPreferencesByUsername string - sqlUpsertSecondFactorPreference string + // Table: identity_verification_tokens. + sqlInsertIdentityVerification string + sqlDeleteIdentityVerification string + sqlSelectExistsIdentityVerification string - sqlTestIdentityVerificationTokenExistence string - sqlInsertIdentityVerificationToken string - sqlDeleteIdentityVerificationToken string + // Table: totp_configurations. + sqlUpsertTOTPConfig string + sqlDeleteTOTPConfig string + sqlSelectTOTPConfig string - sqlGetTOTPSecretByUsername string - sqlUpsertTOTPSecret string - sqlDeleteTOTPSecret string + // Table: u2f_devices. + sqlUpsertU2FDevice string + sqlSelectU2FDevice string - sqlGetU2FDeviceHandleByUsername string - sqlUpsertU2FDeviceHandle string + // Table: user_preferences. + sqlUpsertPreferred2FAMethod string + sqlSelectPreferred2FAMethod string + sqlSelectUserInfo string - sqlInsertAuthenticationLog string - sqlGetLatestAuthenticationLogs string + // Table: migrations. + sqlInsertMigration string + sqlSelectMigrations string + sqlSelectLatestMigration string - sqlGetExistingTables string - - sqlConfigSetValue string - sqlConfigGetValue string + // Utility. + sqlSelectExistingTables string + sqlFmtRenameTable string } -func (p *SQLProvider) initialize(db *sql.DB) error { - p.db = db - p.log = logging.Logger() - - return p.upgrade() -} - -func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) { - rows, err := p.db.Query(p.sqlGetExistingTables) - if err != nil { - return version, tables, err +// StartupCheck implements the provider startup check interface. +func (p *SQLProvider) StartupCheck() (err error) { + if p.errOpen != nil { + return p.errOpen } - defer rows.Close() - - var table string - - for rows.Next() { - err := rows.Scan(&table) - if err != nil { - return version, tables, err + // TODO: Decide if this is needed, or if it should be configurable. + for i := 0; i < 19; i++ { + err = p.db.Ping() + if err == nil { + break } - tables = append(tables, table) + time.Sleep(time.Millisecond * 500) } - if utils.IsStringInSlice(configTableName, tables) { - rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version") - if err != nil { - return version, tables, err - } - - for rows.Next() { - err := rows.Scan(&version) - if err != nil { - return version, tables, err - } - } - } - - return version, tables, nil -} - -func (p *SQLProvider) upgrade() error { - p.log.Debug("Storage schema is being checked to verify it is up to date") - - version, tables, err := p.getSchemaBasicDetails() if err != nil { return err } - if version < storageSchemaCurrentVersion { - p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion) + p.log.Infof("Storage schema is being checked for updates") - tx, err := p.db.Begin() - if err != nil { - return err - } + ctx := context.Background() - switch version { - case 0: - err := p.upgradeSchemaToVersion001(tx, tables) - if err != nil { - return p.handleUpgradeFailure(tx, 1, err) - } + err = p.SchemaMigrate(ctx, true, SchemaLatest) - fallthrough - default: - err := tx.Commit() - if err != nil { - return err - } - - p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion) - } - } else { - p.log.Debug("Storage schema is up to date") + switch err { + case ErrSchemaAlreadyUpToDate: + p.log.Infof("Storage schema is already up to date") + return nil + case nil: + return nil + default: + return err } - - return nil -} - -func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error { - rollbackErr := tx.Rollback() - formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err) - - if rollbackErr != nil { - return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr) - } - - return formattedErr -} - -// LoadPreferred2FAMethod load the preferred method for 2FA from the database. -func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) { - var method string - - rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username) - if err != nil { - return "", err - } - defer rows.Close() - - if !rows.Next() { - return "", nil - } - - err = rows.Scan(&method) - - return method, err } // SavePreferred2FAMethod save the preferred method for 2FA to the database. -func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error { - _, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method) +func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method) + return err } -// FindIdentityVerificationToken look for an identity verification token in the database. -func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) { - var found bool +// LoadPreferred2FAMethod load the preferred method for 2FA from the database. +func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) { + err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username) - err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found) + switch err { + case sql.ErrNoRows: + return "", nil + case nil: + return method, err + default: + return "", err + } +} + +// LoadUserInfo loads the models.UserInfo from the database. +func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) { + err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username) + + switch { + case err == nil: + return info, nil + case errors.Is(err, sql.ErrNoRows): + _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]) + if err != nil { + return models.UserInfo{}, err + } + + err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username) + if err != nil { + return models.UserInfo{}, err + } + + return info, nil + default: + return models.UserInfo{}, err + } +} + +// SaveIdentityVerification save an identity verification record to the database. +func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token) + + return err +} + +// RemoveIdentityVerification remove an identity verification record from the database. +func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token) + + return err +} + +// FindIdentityVerification checks if an identity verification record is in the database and active. +func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) { + err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti) if err != nil { return false, err } @@ -175,105 +201,94 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) return found, nil } -// SaveIdentityVerificationToken save an identity verification token in the database. -func (p *SQLProvider) SaveIdentityVerificationToken(token string) error { - _, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token) +// SaveTOTPConfiguration save a TOTP config of a given user in the database. +func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) { + // TODO: Encrypt config.Secret here. + _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, + config.Username, + config.Algorithm, + config.Digits, + config.Period, + config.Secret, + ) + return err } -// RemoveIdentityVerificationToken remove an identity verification token from the database. -func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error { - _, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token) +// DeleteTOTPConfiguration delete a TOTP secret from the database given a username. +func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username) + return err } -// SaveTOTPSecret save a TOTP secret of a given user in the database. -func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error { - _, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret) - return err -} +// LoadTOTPConfiguration load a TOTP secret given a username from the database. +func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) { + config = &models.TOTPConfiguration{} -// LoadTOTPSecret load a TOTP secret given a username from the database. -func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) { - var secret string - if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil { + err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config) + if err != nil { if err == sql.ErrNoRows { - return "", ErrNoTOTPSecret + return nil, ErrNoTOTPSecret } - return "", err + return nil, err } - return secret, nil + // TODO: Decrypt config.Secret here. + return config, nil } -// DeleteTOTPSecret delete a TOTP secret from the database given a username. -func (p *SQLProvider) DeleteTOTPSecret(username string) error { - _, err := p.db.Exec(p.sqlDeleteTOTPSecret, username) - return err -} - -// SaveU2FDeviceHandle save a registered U2F device registration blob. -func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error { - _, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle, - username, - base64.StdEncoding.EncodeToString(keyHandle), - base64.StdEncoding.EncodeToString(publicKey)) +// SaveU2FDevice saves a registered U2F device. +func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey) return err } -// LoadU2FDeviceHandle load a U2F device registration blob for a given username. -func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) { - var keyHandleBase64, publicKeyBase64 string - if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil { +// LoadU2FDevice loads a U2F device registration for a given username. +func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) { + device = &models.U2FDevice{ + Username: username, + } + + err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username) + if err != nil { if err == sql.ErrNoRows { - return nil, nil, ErrNoU2FDeviceHandle + return nil, ErrNoU2FDeviceHandle } - return nil, nil, err + return nil, err } - keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64) - - if err != nil { - return nil, nil, err - } - - publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64) - - if err != nil { - return nil, nil, err - } - - return keyHandle, publicKey, nil + return device, nil } // AppendAuthenticationLog append a mark to the authentication log. -func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error { - _, err := p.db.Exec(p.sqlInsertAuthenticationLog, attempt.Username, attempt.Successful, attempt.Time.Unix()) +func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username) return err } -// LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log. -func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) { - var t int64 - - rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username) - +// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log. +func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) { + rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page) if err != nil { return nil, err } - attempts := make([]models.AuthenticationAttempt, 0, 10) + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + attempts = make([]models.AuthenticationAttempt, 0, limit) for rows.Next() { - attempt := models.AuthenticationAttempt{ - Username: username, - } - err = rows.Scan(&attempt.Successful, &t) - attempt.Time = time.Unix(t, 0) + var attempt models.AuthenticationAttempt + err = rows.StructScan(&attempt) if err != nil { return nil, err } diff --git a/internal/storage/sql_provider_backend_mysql.go b/internal/storage/sql_provider_backend_mysql.go new file mode 100644 index 000000000..dfb7ec179 --- /dev/null +++ b/internal/storage/sql_provider_backend_mysql.go @@ -0,0 +1,53 @@ +package storage + +import ( + "fmt" + "time" + + _ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string. + + "github.com/authelia/authelia/v4/internal/configuration/schema" +) + +// MySQLProvider is a MySQL provider. +type MySQLProvider struct { + SQLProvider +} + +// NewMySQLProvider a MySQL provider. +func NewMySQLProvider(config schema.MySQLStorageConfiguration) (provider *MySQLProvider) { + provider = &MySQLProvider{ + SQLProvider: NewSQLProvider(providerMySQL, providerMySQL, dataSourceNameMySQL(config)), + } + + // All providers have differing SELECT existing table statements. + provider.sqlSelectExistingTables = queryMySQLSelectExistingTables + + // Specific alterations to this provider. + provider.sqlFmtRenameTable = queryFmtMySQLRenameTable + + return provider +} + +func dataSourceNameMySQL(config schema.MySQLStorageConfiguration) (dataSourceName string) { + dataSourceName = fmt.Sprintf("%s:%s", config.Username, config.Password) + + if dataSourceName != "" { + dataSourceName += "@" + } + + address := config.Host + if config.Port > 0 { + address += fmt.Sprintf(":%d", config.Port) + } + + dataSourceName += fmt.Sprintf("tcp(%s)", address) + if config.Database != "" { + dataSourceName += fmt.Sprintf("/%s", config.Database) + } + + dataSourceName += "?" + dataSourceName += fmt.Sprintf("timeout=%ds&multiStatements=true&parseTime=true", int32(config.Timeout/time.Second)) + + return dataSourceName +} diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go new file mode 100644 index 000000000..5a79afe84 --- /dev/null +++ b/internal/storage/sql_provider_backend_postgres.go @@ -0,0 +1,72 @@ +package storage + +import ( + "fmt" + "strings" + "time" + + _ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string. + + "github.com/authelia/authelia/v4/internal/configuration/schema" +) + +// PostgreSQLProvider is a PostgreSQL provider. +type PostgreSQLProvider struct { + SQLProvider +} + +// NewPostgreSQLProvider a PostgreSQL provider. +func NewPostgreSQLProvider(config schema.PostgreSQLStorageConfiguration) (provider *PostgreSQLProvider) { + provider = &PostgreSQLProvider{ + SQLProvider: NewSQLProvider(providerPostgres, "pgx", dataSourceNamePostgreSQL(config)), + } + + // All providers have differing SELECT existing table statements. + provider.sqlSelectExistingTables = queryPostgreSelectExistingTables + + // Specific alterations to this provider. + // PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead. + provider.sqlUpsertU2FDevice = fmt.Sprintf(queryFmtPostgresUpsertU2FDevice, tableU2FDevices) + provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations) + provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences) + + // PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders. + provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable) + provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod) + provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo) + provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification) + provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification) + provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification) + provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig) + provider.sqlUpsertTOTPConfig = provider.db.Rebind(provider.sqlUpsertTOTPConfig) + provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig) + provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice) + provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt) + provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername) + provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration) + + return provider +} + +func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) { + args := []string{ + fmt.Sprintf("user='%s'", config.Username), + fmt.Sprintf("password='%s'", config.Password), + } + + if config.Host != "" { + args = append(args, fmt.Sprintf("host=%s", config.Host)) + } + + if config.Port > 0 { + args = append(args, fmt.Sprintf("port=%d", config.Port)) + } + + if config.Database != "" { + args = append(args, fmt.Sprintf("dbname=%s", config.Database)) + } + + args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second))) + + return strings.Join(args, " ") +} diff --git a/internal/storage/sql_provider_backend_sqlite.go b/internal/storage/sql_provider_backend_sqlite.go new file mode 100644 index 000000000..b11b82239 --- /dev/null +++ b/internal/storage/sql_provider_backend_sqlite.go @@ -0,0 +1,22 @@ +package storage + +import ( + _ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string. +) + +// SQLiteProvider is a SQLite3 provider. +type SQLiteProvider struct { + SQLProvider +} + +// NewSQLiteProvider constructs a SQLite provider. +func NewSQLiteProvider(path string) (provider *SQLiteProvider) { + provider = &SQLiteProvider{ + SQLProvider: NewSQLProvider(providerSQLite, "sqlite3", path), + } + + // All providers have differing SELECT existing table statements. + provider.sqlSelectExistingTables = querySQLiteSelectExistingTables + + return provider +} diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go new file mode 100644 index 000000000..f11585d7d --- /dev/null +++ b/internal/storage/sql_provider_queries.go @@ -0,0 +1,125 @@ +package storage + +const ( + queryFmtSelectMigrations = ` + SELECT id, applied, version_before, version_after, application_version + FROM %s;` + + queryFmtSelectLatestMigration = ` + SELECT id, applied, version_before, version_after, application_version + FROM %s + ORDER BY id DESC + LIMIT 1;` + + queryFmtInsertMigration = ` + INSERT INTO %s (applied, version_before, version_after, application_version) + VALUES (?, ?, ?, ?);` +) + +const ( + queryMySQLSelectExistingTables = ` + SELECT table_name + FROM information_schema.tables + WHERE table_type = 'BASE TABLE' AND table_schema = database();` + + queryPostgreSelectExistingTables = ` + SELECT table_name + FROM information_schema.tables + WHERE table_type = 'BASE TABLE' AND table_schema = 'public';` + + querySQLiteSelectExistingTables = ` + SELECT name + FROM sqlite_master + WHERE type = 'table';` +) + +const ( + queryFmtSelectUserInfo = ` + SELECT second_factor_method, (SELECT EXISTS (SELECT id FROM %s WHERE username = ?)) AS has_totp, (SELECT EXISTS (SELECT id FROM %s WHERE username = ?)) AS has_u2f + FROM %s + WHERE username = ?;` + + queryFmtSelectPreferred2FAMethod = ` + SELECT second_factor_method + FROM %s + WHERE username = ?;` + + queryFmtUpsertPreferred2FAMethod = ` + REPLACE INTO %s (username, second_factor_method) + VALUES (?, ?);` + + queryFmtPostgresUpsertPreferred2FAMethod = ` + INSERT INTO %s (username, second_factor_method) + VALUES ($1, $2) + ON CONFLICT (username) + DO UPDATE SET second_factor_method = $2;` +) + +const ( + queryFmtSelectExistsIdentityVerification = ` + SELECT EXISTS ( + SELECT id + FROM %s + WHERE token = ? + );` + + queryFmtInsertIdentityVerification = ` + INSERT INTO %s (token) + VALUES (?);` + + queryFmtDeleteIdentityVerification = ` + DELETE FROM %s + WHERE token = ?;` +) + +const ( + queryFmtSelectTOTPConfiguration = ` + SELECT id, username, algorithm, digits, totp_period, secret + FROM %s + WHERE username = ?;` + + queryFmtUpsertTOTPConfiguration = ` + REPLACE INTO %s (username, algorithm, digits, totp_period, secret) + VALUES (?, ?, ?, ?, ?);` + + queryFmtPostgresUpsertTOTPConfiguration = ` + INSERT INTO %s (username, algorithm, digits, totp_period, secret) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (username) + DO UPDATE SET algorithm = $2, digits = $3, totp_period = $4, secret = $5;` + + queryFmtDeleteTOTPConfiguration = ` + DELETE FROM %s + WHERE username = ?;` +) + +const ( + queryFmtSelectU2FDevice = ` + SELECT key_handle, public_key + FROM %s + WHERE username = ?;` + + queryFmtUpsertU2FDevice = ` + REPLACE INTO %s (username, key_handle, public_key) + VALUES (?, ?, ?);` + + queryFmtPostgresUpsertU2FDevice = ` + INSERT INTO %s (username, key_handle, public_key) + VALUES ($1, $2, $3) + ON CONFLICT (username) + DO UPDATE SET key_handle=$2, public_key=$3;` +) + +const ( + queryFmtInsertAuthenticationLogEntry = ` + INSERT INTO %s (time, successful, username) + VALUES (?, ?, ?);` + + queryFmtSelect1FAAuthenticationLogEntryByUsername = ` + SELECT time, successful, username + FROM %s + WHERE time > ? AND username = ? + ORDER BY time DESC + LIMIT ? + OFFSET ?;` +) diff --git a/internal/storage/sql_provider_queries_special.go b/internal/storage/sql_provider_queries_special.go new file mode 100644 index 000000000..7e9ba24dd --- /dev/null +++ b/internal/storage/sql_provider_queries_special.go @@ -0,0 +1,109 @@ +package storage + +const ( + queryFmtDropTableIfExists = `DROP TABLE IF EXISTS %s;` + + queryFmtRenameTable = ` + ALTER TABLE %s + RENAME TO %s;` + + queryFmtMySQLRenameTable = ` + ALTER TABLE %s + RENAME %s;` +) + +// Pre1 migration constants. +const ( + queryFmtPre1To1SelectAuthenticationLogs = ` + SELECT username, successful, time + FROM %s + ORDER BY time ASC + LIMIT 100 OFFSET ?;` + + queryFmtPre1To1InsertAuthenticationLogs = ` + INSERT INTO %s (username, successful, time) + VALUES (?, ?, ?);` + + queryFmtPre1InsertUserPreferencesFromSelect = ` + INSERT INTO %s (username, second_factor_method) + SELECT username, second_factor_method + FROM %s + ORDER BY username ASC;` + + queryFmtPre1SelectTOTPConfigurations = ` + SELECT username, secret + FROM %s + ORDER BY username ASC;` + + queryFmtPre1InsertTOTPConfiguration = ` + INSERT INTO %s (username, secret) + VALUES (?, ?);` + + queryFmtPre1To1SelectU2FDevices = ` + SELECT username, keyHandle, publicKey + FROM %s + ORDER BY username ASC;` + + queryFmtPre1To1InsertU2FDevice = ` + INSERT INTO %s (username, key_handle, public_key) + VALUES (?, ?, ?);` + + queryFmt1ToPre1InsertAuthenticationLogs = ` + INSERT INTO %s (username, successful, time) + VALUES (?, ?, ?);` + + queryFmt1ToPre1SelectAuthenticationLogs = ` + SELECT username, successful, time + FROM %s + ORDER BY id ASC + LIMIT 100 OFFSET ?;` + + queryFmt1ToPre1SelectU2FDevices = ` + SELECT username, key_handle, public_key + FROM %s + ORDER BY username ASC;` + + queryFmt1ToPre1InsertU2FDevice = ` + INSERT INTO %s (username, keyHandle, publicKey) + VALUES (?, ?, ?);` + + queryCreatePre1 = ` + CREATE TABLE user_preferences ( + username VARCHAR(100), + second_factor_method VARCHAR(11), + PRIMARY KEY (username) + ); + + CREATE TABLE identity_verification_tokens ( + token VARCHAR(512) + ); + + CREATE TABLE totp_secrets ( + username VARCHAR(100), + secret VARCHAR(64), + PRIMARY KEY (username) + ); + + CREATE TABLE u2f_devices ( + username VARCHAR(100), + keyHandle TEXT, + publicKey TEXT, + PRIMARY KEY (username) + ); + + CREATE TABLE authentication_logs ( + username VARCHAR(100), + successful BOOL, + time INTEGER + ); + + CREATE TABLE config ( + category VARCHAR(32) NOT NULL, + key_name VARCHAR(32) NOT NULL, + value TEXT, + PRIMARY KEY (category, key_name) + ); + + INSERT INTO config (category, key_name, value) + VALUES ('schema', 'version', '1');` +) diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go new file mode 100644 index 000000000..6abb5ff58 --- /dev/null +++ b/internal/storage/sql_provider_schema.go @@ -0,0 +1,327 @@ +package storage + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/authelia/authelia/v4/internal/models" + "github.com/authelia/authelia/v4/internal/utils" +) + +// SchemaTables returns a list of tables. +func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) { + rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables) + if err != nil { + return tables, err + } + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + var table string + + for rows.Next() { + err = rows.Scan(&table) + if err != nil { + return []string{}, err + } + + tables = append(tables, table) + } + + return tables, nil +} + +// SchemaVersion returns the version of the schema. +func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error) { + tables, err := p.SchemaTables(ctx) + if err != nil { + return -2, err + } + + if len(tables) == 0 { + return 0, nil + } + + if utils.IsStringInSlice(tableMigrations, tables) { + migration, err := p.schemaLatestMigration(ctx) + if err != nil { + return -2, err + } + + return migration.After, nil + } + + if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) && + utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) && + utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) { + return -1, nil + } + + // TODO: Decide if we want to support external tables. + // return -2, ErrUnknownSchemaState + return 0, nil +} + +func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *models.Migration, err error) { + migration = &models.Migration{} + + err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration) + if err != nil { + return nil, err + } + + return migration, nil +} + +// SchemaMigrationHistory returns migration history rows. +func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error) { + rows, err := p.db.QueryxContext(ctx, p.sqlSelectMigrations) + if err != nil { + return nil, err + } + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + var migration models.Migration + + for rows.Next() { + err = rows.StructScan(&migration) + if err != nil { + return nil, err + } + + migrations = append(migrations, migration) + } + + return migrations, nil +} + +// SchemaMigrate migrates from the current version to the provided version. +func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) { + currentVersion, err := p.SchemaVersion(ctx) + if err != nil { + return err + } + + if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil { + return err + } + + return p.schemaMigrate(ctx, currentVersion, version) +} + +func (p *SQLProvider) schemaMigrate(ctx context.Context, prior, target int) (err error) { + migrations, err := loadMigrations(p.name, prior, target) + if err != nil { + return err + } + + if len(migrations) == 0 { + return ErrNoMigrationsFound + } + + switch { + case prior == -1: + p.log.Infof(logFmtMigrationFromTo, "pre1", strconv.Itoa(migrations[len(migrations)-1].After())) + + err = p.schemaMigratePre1To1(ctx) + if err != nil { + if errRollback := p.schemaMigratePre1To1Rollback(ctx, true); errRollback != nil { + return fmt.Errorf(errFmtFailedMigrationPre1, err) + } + + return fmt.Errorf(errFmtFailedMigrationPre1, err) + } + case target == -1: + p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), "pre1") + default: + p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After())) + } + + for _, migration := range migrations { + if prior == -1 && migration.Version == 1 { + // Skip migration version 1 when upgrading from pre1 as it's applied as part of the pre1 upgrade. + continue + } + + err = p.schemaMigrateApply(ctx, migration) + if err != nil { + return p.schemaMigrateRollback(ctx, prior, migration.After(), err) + } + } + + switch { + case prior == -1: + p.log.Infof(logFmtMigrationComplete, "pre1", strconv.Itoa(migrations[len(migrations)-1].After())) + case target == -1: + err = p.schemaMigrate1ToPre1(ctx) + if err != nil { + if errRollback := p.schemaMigratePre1To1Rollback(ctx, false); errRollback != nil { + return fmt.Errorf(errFmtFailedMigrationPre1, err) + } + + return fmt.Errorf(errFmtFailedMigrationPre1, err) + } + + p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), "pre1") + default: + p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After())) + } + + return nil +} + +func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after int, migrateErr error) (err error) { + migrations, err := loadMigrations(p.name, after, prior) + if err != nil { + return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %+v", prior, after, err, migrateErr) + } + + for _, migration := range migrations { + if prior == -1 && !migration.Up && migration.Version == 1 { + continue + } + + err = p.schemaMigrateApply(ctx, migration) + if err != nil { + return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %+v", migration.Before(), migration.After(), err, migrateErr) + } + } + + if prior == -1 { + if err = p.schemaMigrate1ToPre1(ctx); err != nil { + return fmt.Errorf("error applying migration version 1 to version pre1 for rollback: %+v. rollback caused by: %+v", err, migrateErr) + } + } + + return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr) +} + +func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration SchemaMigration) (err error) { + _, err = p.db.ExecContext(ctx, migration.Query) + if err != nil { + return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) + } + + // Skip the migration history insertion in a migration to v0. + if migration.Version == 1 && !migration.Up { + return nil + } + + return p.schemaMigrateFinalize(ctx, migration) +} + +func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration SchemaMigration) (err error) { + return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After()) +} + +func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, after int) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlInsertMigration, time.Now(), before, after, utils.Version()) + if err != nil { + return err + } + + p.log.Debugf("Storage schema migrated from version %d to %d", before, after) + + return nil +} + +// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version. +func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error) { + current, err := p.SchemaVersion(ctx) + if err != nil { + return migrations, err + } + + if version == 0 { + version = SchemaLatest + } + + if current >= version { + return migrations, ErrNoAvailableMigrations + } + + return loadMigrations(p.name, current, version) +} + +// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version. +func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error) { + current, err := p.SchemaVersion(ctx) + if err != nil { + return migrations, err + } + + if current <= version { + return migrations, ErrNoAvailableMigrations + } + + return loadMigrations(p.name, current, version) +} + +// SchemaLatestVersion returns the latest version available for migration.. +func (p *SQLProvider) SchemaLatestVersion() (version int, err error) { + return latestMigrationVersion(p.name) +} + +func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) { + if targetVersion == currentVersion { + return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion) + } + + latest, err := latestMigrationVersion(providerName) + if err != nil { + return err + } + + if currentVersion > latest { + return fmt.Errorf(errFmtSchemaCurrentGreaterThanLatestKnown, latest) + } + + if up { + if targetVersion < currentVersion { + return fmt.Errorf(ErrFmtMigrateUpTargetLessThanCurrent, targetVersion, currentVersion) + } + + if targetVersion == SchemaLatest && latest == currentVersion { + return ErrSchemaAlreadyUpToDate + } + + if targetVersion != SchemaLatest && latest < targetVersion { + return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest) + } + } else { + if targetVersion < -1 { + return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion) + } + + if targetVersion > currentVersion { + return fmt.Errorf(ErrFmtMigrateDownTargetGreaterThanCurrent, targetVersion, currentVersion) + } + } + + return nil +} + +// SchemaVersionToString returns a version string given a version number. +func SchemaVersionToString(version int) (versionStr string) { + switch version { + case -2: + return "unknown" + case -1: + return "pre1" + case 0: + return "N/A" + default: + return strconv.Itoa(version) + } +} diff --git a/internal/storage/sql_provider_schema_pre1.go b/internal/storage/sql_provider_schema_pre1.go new file mode 100644 index 000000000..a8aad4509 --- /dev/null +++ b/internal/storage/sql_provider_schema_pre1.go @@ -0,0 +1,449 @@ +package storage + +import ( + "context" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + "github.com/authelia/authelia/v4/internal/models" + "github.com/authelia/authelia/v4/internal/utils" +) + +// schemaMigratePre1To1 takes the v1 migration and migrates to this version. +func (p *SQLProvider) schemaMigratePre1To1(ctx context.Context) (err error) { + migration, err := loadMigration(p.name, 1, true) + if err != nil { + return err + } + + // Get Tables list. + tables, err := p.SchemaTables(ctx) + if err != nil { + return err + } + + tablesRename := []string{ + tablePre1Config, + tablePre1TOTPSecrets, + tablePre1IdentityVerificationTokens, + tableU2FDevices, + tableUserPreferences, + tableAuthenticationLogs, + tableAlphaPreferences, + tableAlphaIdentityVerificationTokens, + tableAlphaAuthenticationLogs, + tableAlphaPreferencesTableName, + tableAlphaSecondFactorPreferences, + tableAlphaTOTPSecrets, + tableAlphaU2FDeviceHandles, + } + + if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil { + return err + } + + if _, err = p.db.ExecContext(ctx, migration.Query); err != nil { + return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) + } + + if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect), + tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil { + return err + } + + if err = p.schemaMigratePre1To1AuthenticationLogs(ctx); err != nil { + return err + } + + if err = p.schemaMigratePre1To1U2F(ctx); err != nil { + return err + } + + if err = p.schemaMigratePre1To1TOTP(ctx); err != nil { + return err + } + + for _, table := range tablesRename { + if _, err = p.db.Exec(fmt.Sprintf(p.db.Rebind(queryFmtDropTableIfExists), tablePrefixBackup+table)); err != nil { + return err + } + } + + return p.schemaMigrateFinalizeAdvanced(ctx, -1, 1) +} + +func (p *SQLProvider) schemaMigratePre1Rename(ctx context.Context, tables, tablesRename []string) (err error) { + // Rename Tables and Indexes. + for _, table := range tables { + if !utils.IsStringInSlice(table, tablesRename) { + continue + } + + tableNew := tablePrefixBackup + table + + if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil { + return err + } + + if p.name == providerPostgres { + if table == tableU2FDevices || table == tableUserPreferences { + if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`, + tableNew, table, tableNew)); err != nil { + continue + } + } + } + } + + return nil +} + +func (p *SQLProvider) schemaMigratePre1To1Rollback(ctx context.Context, up bool) (err error) { + if up { + migration, err := loadMigration(p.name, 1, false) + if err != nil { + return err + } + + if _, err = p.db.ExecContext(ctx, migration.Query); err != nil { + return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) + } + } + + tables, err := p.SchemaTables(ctx) + if err != nil { + return err + } + + for _, table := range tables { + if !strings.HasPrefix(table, tablePrefixBackup) { + continue + } + + tableNew := strings.Replace(table, tablePrefixBackup, "", 1) + if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil { + return err + } + + if p.name == providerPostgres && (tableNew == tableU2FDevices || tableNew == tableUserPreferences) { + if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`, + tableNew, table, tableNew)); err != nil { + continue + } + } + } + + return nil +} + +func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogs(ctx context.Context) (err error) { + for page := 0; true; page++ { + attempts, err := p.schemaMigratePre1To1AuthenticationLogsGetRows(ctx, page) + if err != nil { + if err == sql.ErrNoRows { + break + } + + return err + } + + for _, attempt := range attempts { + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time) + if err != nil { + return err + } + } + + if len(attempts) != 100 { + break + } + } + + return nil +} + +func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []models.AuthenticationAttempt, err error) { + rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100) + if err != nil { + return nil, err + } + + attempts = make([]models.AuthenticationAttempt, 0, 100) + + for rows.Next() { + var ( + username string + successful bool + timestamp int64 + ) + + err = rows.Scan(&username, &successful, ×tamp) + if err != nil { + return nil, err + } + + attempts = append(attempts, models.AuthenticationAttempt{Username: username, Successful: successful, Time: time.Unix(timestamp, 0)}) + } + + return attempts, nil +} + +func (p *SQLProvider) schemaMigratePre1To1TOTP(ctx context.Context) (err error) { + rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tablePre1TOTPSecrets)) + if err != nil { + return err + } + + var totpConfigs []models.TOTPConfiguration + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + for rows.Next() { + var username, secret string + + err = rows.Scan(&username, &secret) + if err != nil { + return err + } + + // TODO: Add encryption migration here. + encryptedSecret := "encrypted:" + secret + + totpConfigs = append(totpConfigs, models.TOTPConfiguration{Username: username, Secret: encryptedSecret}) + } + + for _, config := range totpConfigs { + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tableTOTPConfigurations), config.Username, config.Secret) + if err != nil { + return err + } + } + + return nil +} + +func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) { + rows, err := p.db.Queryx(fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectU2FDevices), tablePrefixBackup+tableU2FDevices)) + if err != nil { + return err + } + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + var devices []models.U2FDevice + + for rows.Next() { + var username, keyHandleBase64, publicKeyBase64 string + + err = rows.Scan(&username, &keyHandleBase64, &publicKeyBase64) + if err != nil { + return err + } + + keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64) + if err != nil { + return err + } + + publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64) + if err != nil { + return err + } + + devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey}) + } + + for _, device := range devices { + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertU2FDevice), tableU2FDevices), device.Username, device.KeyHandle, device.PublicKey) + if err != nil { + return err + } + } + + return nil +} + +func (p *SQLProvider) schemaMigrate1ToPre1(ctx context.Context) (err error) { + tables, err := p.SchemaTables(ctx) + if err != nil { + return err + } + + tablesRename := []string{ + tableMigrations, + tableTOTPConfigurations, + tableIdentityVerification, + tableU2FDevices, + tableDUODevices, + tableUserPreferences, + tableAuthenticationLogs, + } + + if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil { + return err + } + + if _, err := p.db.ExecContext(ctx, queryCreatePre1); err != nil { + return err + } + + if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect), + tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil { + return err + } + + if err = p.schemaMigrate1ToPre1AuthenticationLogs(ctx); err != nil { + return err + } + + if err = p.schemaMigrate1ToPre1U2F(ctx); err != nil { + return err + } + + if err = p.schemaMigrate1ToPre1TOTP(ctx); err != nil { + return err + } + + queryFmtDropTableRebound := p.db.Rebind(queryFmtDropTableIfExists) + + for _, table := range tablesRename { + if _, err = p.db.Exec(fmt.Sprintf(queryFmtDropTableRebound, tablePrefixBackup+table)); err != nil { + return err + } + } + + return nil +} + +func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogs(ctx context.Context) (err error) { + for page := 0; true; page++ { + attempts, err := p.schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx, page) + if err != nil { + if err == sql.ErrNoRows { + break + } + + return err + } + + for _, attempt := range attempts { + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time.Unix()) + if err != nil { + return err + } + } + + if len(attempts) != 100 { + break + } + } + + return nil +} + +func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []models.AuthenticationAttempt, err error) { + rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100) + if err != nil { + return nil, err + } + + attempts = make([]models.AuthenticationAttempt, 0, 100) + + var attempt models.AuthenticationAttempt + for rows.Next() { + err = rows.StructScan(&attempt) + if err != nil { + return nil, err + } + + attempts = append(attempts, attempt) + } + + return attempts, nil +} + +func (p *SQLProvider) schemaMigrate1ToPre1TOTP(ctx context.Context) (err error) { + rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tableTOTPConfigurations)) + if err != nil { + return err + } + + var totpConfigs []models.TOTPConfiguration + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + for rows.Next() { + var username, encryptedSecret string + + err = rows.Scan(&username, &encryptedSecret) + if err != nil { + return err + } + + // TODO: Fix. + // TODO: Add DECRYPTION migration here. + decryptedSecret := strings.Replace(encryptedSecret, "encrypted:", "", 1) + + totpConfigs = append(totpConfigs, models.TOTPConfiguration{Username: username, Secret: decryptedSecret}) + } + + for _, config := range totpConfigs { + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tablePre1TOTPSecrets), config.Username, config.Secret) + if err != nil { + return err + } + } + + return nil +} + +func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) { + rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectU2FDevices), tablePrefixBackup+tableU2FDevices)) + if err != nil { + return err + } + + defer func() { + if err := rows.Close(); err != nil { + p.log.Errorf(logFmtErrClosingConn, err) + } + }() + + var ( + devices []models.U2FDevice + device models.U2FDevice + ) + + for rows.Next() { + err = rows.StructScan(&device) + if err != nil { + return err + } + + devices = append(devices, device) + } + + for _, device := range devices { + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertU2FDevice), tableU2FDevices), device.Username, base64.StdEncoding.EncodeToString(device.KeyHandle), base64.StdEncoding.EncodeToString(device.PublicKey)) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/storage/sql_provider_schema_test.go b/internal/storage/sql_provider_schema_test.go new file mode 100644 index 000000000..8769e52bf --- /dev/null +++ b/internal/storage/sql_provider_schema_test.go @@ -0,0 +1,134 @@ +package storage + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShouldReturnErrOnTargetSameAsCurrent(t *testing.T) { + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, true, 1, 1), + fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1)) + + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, false, 1, 1), + fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1)) + + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, false, 2, 2), + fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 2, 2)) + + assert.EqualError(t, + schemaMigrateChecks(providerMySQL, false, 1, 1), + fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1)) + + assert.EqualError(t, + schemaMigrateChecks(providerPostgres, false, 1, 1), + fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1)) +} + +func TestShouldReturnErrOnUpMigrationTargetVersionLessTHanCurrent(t *testing.T) { + assert.EqualError(t, + schemaMigrateChecks(providerPostgres, true, 0, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion)) + + assert.NoError(t, + schemaMigrateChecks(providerPostgres, true, testLatestVersion, 0)) + + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, true, 0, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion)) + + assert.NoError(t, + schemaMigrateChecks(providerSQLite, true, testLatestVersion, 0)) + + assert.EqualError(t, + schemaMigrateChecks(providerMySQL, true, 0, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion)) + + assert.NoError(t, + schemaMigrateChecks(providerMySQL, true, testLatestVersion, 0)) +} + +func TestMigrationUpShouldReturnErrOnAlreadyLatest(t *testing.T) { + assert.Equal(t, + ErrSchemaAlreadyUpToDate, + schemaMigrateChecks(providerPostgres, true, SchemaLatest, testLatestVersion)) + + assert.Equal(t, + ErrSchemaAlreadyUpToDate, + schemaMigrateChecks(providerMySQL, true, SchemaLatest, testLatestVersion)) + + assert.Equal(t, + ErrSchemaAlreadyUpToDate, + schemaMigrateChecks(providerSQLite, true, SchemaLatest, testLatestVersion)) +} + +func TestShouldReturnErrOnVersionDoesntExits(t *testing.T) { + assert.EqualError(t, + schemaMigrateChecks(providerPostgres, true, SchemaLatest-1, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion)) + + assert.EqualError(t, + schemaMigrateChecks(providerMySQL, true, SchemaLatest-1, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion)) + + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, true, SchemaLatest-1, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion)) +} + +func TestMigrationDownShouldReturnErrOnTargetLessThanPre1(t *testing.T) { + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, false, -4, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -4)) + + assert.EqualError(t, + schemaMigrateChecks(providerMySQL, false, -2, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2)) + + assert.EqualError(t, + schemaMigrateChecks(providerPostgres, false, -2, testLatestVersion), + fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2)) + + assert.NoError(t, + schemaMigrateChecks(providerPostgres, false, -1, testLatestVersion)) +} + +func TestMigrationDownShouldReturnErrOnTargetVersionGreaterThanCurrent(t *testing.T) { + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, false, testLatestVersion, 0), + fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0)) + + assert.EqualError(t, + schemaMigrateChecks(providerMySQL, false, testLatestVersion, 0), + fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0)) + + assert.EqualError(t, + schemaMigrateChecks(providerPostgres, false, testLatestVersion, 0), + fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0)) +} + +func TestShouldReturnErrWhenCurrentIsGreaterThanLatest(t *testing.T) { + assert.EqualError(t, + schemaMigrateChecks(providerPostgres, true, SchemaLatest-4, SchemaLatest-5), + fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion)) + + assert.EqualError(t, + schemaMigrateChecks(providerMySQL, true, SchemaLatest-4, SchemaLatest-5), + fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion)) + + assert.EqualError(t, + schemaMigrateChecks(providerSQLite, true, SchemaLatest-4, SchemaLatest-5), + fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion)) +} + +func TestSchemaVersionToString(t *testing.T) { + assert.Equal(t, "unknown", SchemaVersionToString(-2)) + assert.Equal(t, "pre1", SchemaVersionToString(-1)) + assert.Equal(t, "N/A", SchemaVersionToString(0)) + assert.Equal(t, "1", SchemaVersionToString(1)) + assert.Equal(t, "2", SchemaVersionToString(2)) +} diff --git a/internal/storage/sql_provider_test.go b/internal/storage/sql_provider_test.go deleted file mode 100644 index bf978571a..000000000 --- a/internal/storage/sql_provider_test.go +++ /dev/null @@ -1,400 +0,0 @@ -package storage - -import ( - "database/sql/driver" - "encoding/base64" - "fmt" - "sort" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/authelia/authelia/v4/internal/authentication" - "github.com/authelia/authelia/v4/internal/models" -) - -const currentSchemaMockSchemaVersion = "1" - -func TestSQLInitializeDatabase(t *testing.T) { - provider, mock := NewSQLMockProvider() - - rows := sqlmock.NewRows([]string{"name"}) - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(rows) - - mock.ExpectBegin() - - keys := make([]string, 0, len(sqlUpgradeCreateTableStatements[1])) - for k := range sqlUpgradeCreateTableStatements[1] { - keys = append(keys, k) - } - - sort.Strings(keys) - - for _, table := range keys { - mock.ExpectExec( - fmt.Sprintf("CREATE TABLE %s .*", table)). - WillReturnResult(sqlmock.NewResult(0, 0)) - } - - mock.ExpectExec( - fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec( - fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)). - WithArgs("schema", "version", "1"). - WillReturnResult(sqlmock.NewResult(1, 1)) - - mock.ExpectCommit() - - err := provider.initialize(provider.db) - assert.NoError(t, err) -} - -func TestSQLUpgradeDatabase(t *testing.T) { - provider, mock := NewSQLMockProvider() - - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow(userPreferencesTableName). - AddRow(identityVerificationTokensTableName). - AddRow(totpSecretsTableName). - AddRow(u2fDeviceHandlesTableName). - AddRow(authenticationLogsTableName)) - - mock.ExpectBegin() - - mock.ExpectExec( - fmt.Sprintf("CREATE TABLE %s .*", configTableName)). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec( - fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec( - fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)). - WithArgs("schema", "version", "1"). - WillReturnResult(sqlmock.NewResult(1, 1)) - - mock.ExpectCommit() - - err := provider.initialize(provider.db) - assert.NoError(t, err) -} - -func TestSQLProviderMethodsAuthenticationLogs(t *testing.T) { - provider, mock := NewSQLMockProvider() - - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow(userPreferencesTableName). - AddRow(identityVerificationTokensTableName). - AddRow(totpSecretsTableName). - AddRow(u2fDeviceHandlesTableName). - AddRow(authenticationLogsTableName). - AddRow(configTableName)) - - args := []driver.Value{"schema", "version"} - mock.ExpectQuery( - fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"value"}). - AddRow("1")) - - err := provider.initialize(provider.db) - assert.NoError(t, err) - - attempts := []models.AuthenticationAttempt{ - {Username: unitTestUser, Successful: true, Time: time.Unix(1577880001, 0)}, - {Username: unitTestUser, Successful: true, Time: time.Unix(1577880002, 0)}, - {Username: unitTestUser, Successful: false, Time: time.Unix(1577880003, 0)}, - } - - rows := sqlmock.NewRows([]string{"successful", "time"}) - - for id, attempt := range attempts { - args = []driver.Value{attempt.Username, attempt.Successful, attempt.Time.Unix()} - mock.ExpectExec( - fmt.Sprintf("INSERT INTO %s \\(username, successful, time\\) VALUES \\(\\?, \\?, \\?\\)", authenticationLogsTableName)). - WithArgs(args...). - WillReturnResult(sqlmock.NewResult(int64(id), 1)) - - err := provider.AppendAuthenticationLog(attempt) - assert.NoError(t, err) - rows.AddRow(attempt.Successful, attempt.Time.Unix()) - } - - args = []driver.Value{1577880000, unitTestUser} - mock.ExpectQuery( - fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)). - WithArgs(args...). - WillReturnRows(rows) - - after := time.Unix(1577880000, 0) - results, err := provider.LoadLatestAuthenticationLogs(unitTestUser, after) - assert.NoError(t, err) - require.Len(t, results, 3) - assert.Equal(t, unitTestUser, results[0].Username) - assert.Equal(t, true, results[0].Successful) - assert.Equal(t, time.Unix(1577880001, 0), results[0].Time) - assert.Equal(t, unitTestUser, results[1].Username) - assert.Equal(t, true, results[1].Successful) - assert.Equal(t, time.Unix(1577880002, 0), results[1].Time) - assert.Equal(t, unitTestUser, results[2].Username) - assert.Equal(t, false, results[2].Successful) - assert.Equal(t, time.Unix(1577880003, 0), results[2].Time) - - // Test Blank Rows. - mock.ExpectQuery( - fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"successful", "time"})) - - results, err = provider.LoadLatestAuthenticationLogs(unitTestUser, after) - assert.NoError(t, err) - assert.Len(t, results, 0) -} - -func TestSQLProviderMethodsPreferred(t *testing.T) { - provider, mock := NewSQLMockProvider() - - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow(userPreferencesTableName). - AddRow(identityVerificationTokensTableName). - AddRow(totpSecretsTableName). - AddRow(u2fDeviceHandlesTableName). - AddRow(authenticationLogsTableName). - AddRow(configTableName)) - - args := []driver.Value{"schema", "version"} - mock.ExpectQuery( - fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"value"}). - AddRow(currentSchemaMockSchemaVersion)) - - err := provider.initialize(provider.db) - assert.NoError(t, err) - - mock.ExpectExec( - fmt.Sprintf("REPLACE INTO %s \\(username, second_factor_method\\) VALUES \\(\\?, \\?\\)", userPreferencesTableName)). - WithArgs(unitTestUser, authentication.TOTP). - WillReturnResult(sqlmock.NewResult(0, 1)) - - err = provider.SavePreferred2FAMethod(unitTestUser, authentication.TOTP) - assert.NoError(t, err) - - mock.ExpectQuery( - fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)). - WithArgs(unitTestUser). - WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}).AddRow(authentication.TOTP)) - - method, err := provider.LoadPreferred2FAMethod(unitTestUser) - assert.NoError(t, err) - assert.Equal(t, authentication.TOTP, method) - - // Test Blank Rows. - mock.ExpectQuery( - fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)). - WithArgs(unitTestUser). - WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"})) - - method, err = provider.LoadPreferred2FAMethod(unitTestUser) - assert.NoError(t, err) - assert.Equal(t, "", method) -} - -func TestSQLProviderMethodsTOTP(t *testing.T) { - provider, mock := NewSQLMockProvider() - - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow(userPreferencesTableName). - AddRow(identityVerificationTokensTableName). - AddRow(totpSecretsTableName). - AddRow(u2fDeviceHandlesTableName). - AddRow(authenticationLogsTableName). - AddRow(configTableName)) - - args := []driver.Value{"schema", "version"} - mock.ExpectQuery( - fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"value"}). - AddRow(currentSchemaMockSchemaVersion)) - - err := provider.initialize(provider.db) - assert.NoError(t, err) - - pretendSecret := "abc123" - args = []driver.Value{unitTestUser, pretendSecret} - mock.ExpectExec( - fmt.Sprintf("REPLACE INTO %s \\(username, secret\\) VALUES \\(\\?, \\?\\)", totpSecretsTableName)). - WithArgs(args...). - WillReturnResult(sqlmock.NewResult(0, 1)) - - err = provider.SaveTOTPSecret(unitTestUser, pretendSecret) - assert.NoError(t, err) - - args = []driver.Value{unitTestUser} - mock.ExpectQuery( - fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"secret"}).AddRow(pretendSecret)) - - secret, err := provider.LoadTOTPSecret(unitTestUser) - assert.NoError(t, err) - assert.Equal(t, pretendSecret, secret) - - mock.ExpectExec( - fmt.Sprintf("DELETE FROM %s WHERE username=\\?", totpSecretsTableName)). - WithArgs(unitTestUser). - WillReturnResult(sqlmock.NewResult(0, 1)) - - err = provider.DeleteTOTPSecret(unitTestUser) - assert.NoError(t, err) - - mock.ExpectQuery( - fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"secret"})) - - // Test Blank Rows - secret, err = provider.LoadTOTPSecret(unitTestUser) - assert.EqualError(t, err, "no TOTP secret registered") - assert.Equal(t, "", secret) -} - -func TestSQLProviderMethodsU2F(t *testing.T) { - provider, mock := NewSQLMockProvider() - - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow(userPreferencesTableName). - AddRow(identityVerificationTokensTableName). - AddRow(totpSecretsTableName). - AddRow(u2fDeviceHandlesTableName). - AddRow(authenticationLogsTableName). - AddRow(configTableName)) - - args := []driver.Value{"schema", "version"} - mock.ExpectQuery( - fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"value"}). - AddRow(currentSchemaMockSchemaVersion)) - - err := provider.initialize(provider.db) - assert.NoError(t, err) - - pretendKeyHandle := []byte("abc") - pretendPublicKey := []byte("123") - pretendKeyHandleB64 := base64.StdEncoding.EncodeToString(pretendKeyHandle) - pretendPublicKeyB64 := base64.StdEncoding.EncodeToString(pretendPublicKey) - - args = []driver.Value{unitTestUser, pretendKeyHandleB64, pretendPublicKeyB64} - mock.ExpectExec( - fmt.Sprintf("REPLACE INTO %s \\(username, keyHandle, publicKey\\) VALUES \\(\\?, \\?, \\?\\)", u2fDeviceHandlesTableName)). - WithArgs(args...). - WillReturnResult(sqlmock.NewResult(0, 1)) - - err = provider.SaveU2FDeviceHandle(unitTestUser, pretendKeyHandle, pretendPublicKey) - assert.NoError(t, err) - - args = []driver.Value{unitTestUser} - mock.ExpectQuery( - fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}). - AddRow(pretendKeyHandleB64, pretendPublicKeyB64)) - - keyHandle, publicKey, err := provider.LoadU2FDeviceHandle(unitTestUser) - assert.NoError(t, err) - assert.Equal(t, pretendKeyHandle, keyHandle) - assert.Equal(t, pretendPublicKey, publicKey) - - // Test Blank Rows. - mock.ExpectQuery( - fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"})) - - keyHandle, publicKey, err = provider.LoadU2FDeviceHandle(unitTestUser) - assert.EqualError(t, err, "no U2F device handle found") - assert.Equal(t, []byte(nil), keyHandle) - assert.Equal(t, []byte(nil), publicKey) -} - -func TestSQLProviderMethodsIdentityVerificationTokens(t *testing.T) { - provider, mock := NewSQLMockProvider() - - mock.ExpectQuery( - "SELECT name FROM sqlite_master WHERE type='table'"). - WillReturnRows(sqlmock.NewRows([]string{"name"}). - AddRow(userPreferencesTableName). - AddRow(identityVerificationTokensTableName). - AddRow(totpSecretsTableName). - AddRow(u2fDeviceHandlesTableName). - AddRow(authenticationLogsTableName). - AddRow(configTableName)) - - args := []driver.Value{"schema", "version"} - mock.ExpectQuery( - fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)). - WithArgs(args...). - WillReturnRows(sqlmock.NewRows([]string{"value"}). - AddRow(currentSchemaMockSchemaVersion)) - - err := provider.initialize(provider.db) - assert.NoError(t, err) - - fakeIdentityVerificationToken := "abc" - - mock.ExpectExec( - fmt.Sprintf("INSERT INTO %s \\(token\\) VALUES \\(\\?\\)", identityVerificationTokensTableName)). - WithArgs(fakeIdentityVerificationToken). - WillReturnResult(sqlmock.NewResult(1, 1)) - - err = provider.SaveIdentityVerificationToken(fakeIdentityVerificationToken) - assert.NoError(t, err) - - mock.ExpectQuery( - fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)). - WithArgs(fakeIdentityVerificationToken). - WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}). - AddRow(true)) - - valid, err := provider.FindIdentityVerificationToken(fakeIdentityVerificationToken) - assert.NoError(t, err) - assert.True(t, valid) - - mock.ExpectExec( - fmt.Sprintf("DELETE FROM %s WHERE token=\\?", identityVerificationTokensTableName)). - WithArgs(fakeIdentityVerificationToken). - WillReturnResult(sqlmock.NewResult(0, 1)) - - err = provider.RemoveIdentityVerificationToken(fakeIdentityVerificationToken) - assert.NoError(t, err) - - mock.ExpectQuery( - fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)). - WithArgs(fakeIdentityVerificationToken). - WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}). - AddRow(false)) - - valid, err = provider.FindIdentityVerificationToken(fakeIdentityVerificationToken) - assert.NoError(t, err) - assert.False(t, valid) -} diff --git a/internal/storage/sqlite_provider.go b/internal/storage/sqlite_provider.go deleted file mode 100644 index 37331ec67..000000000 --- a/internal/storage/sqlite_provider.go +++ /dev/null @@ -1,58 +0,0 @@ -package storage - -import ( - "database/sql" - "fmt" - - _ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string. -) - -// SQLiteProvider is a SQLite3 provider. -type SQLiteProvider struct { - SQLProvider -} - -// NewSQLiteProvider constructs a SQLite provider. -func NewSQLiteProvider(path string) *SQLiteProvider { - provider := SQLiteProvider{ - SQLProvider{ - name: "sqlite", - - sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements, - sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements, - - sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName), - sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName), - - sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName), - sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName), - sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName), - - sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName), - sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName), - sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName), - - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName), - - sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName), - sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName), - - sqlGetExistingTables: "SELECT name FROM sqlite_master WHERE type='table'", - - sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName), - sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName), - }, - } - - db, err := sql.Open("sqlite3", path) - if err != nil { - provider.log.Fatalf("Unable to create SQL database %s: %s", path, err) - } - - if err := provider.initialize(db); err != nil { - provider.log.Fatalf("Unable to initialize SQL database %s: %s", path, err) - } - - return &provider -} diff --git a/internal/storage/sqlmock_provider.go b/internal/storage/sqlmock_provider.go deleted file mode 100644 index 324b480a5..000000000 --- a/internal/storage/sqlmock_provider.go +++ /dev/null @@ -1,60 +0,0 @@ -package storage - -import ( - "fmt" - - "github.com/DATA-DOG/go-sqlmock" -) - -// SQLMockProvider is a SQLMock provider. -type SQLMockProvider struct { - SQLProvider -} - -// NewSQLMockProvider constructs a SQLMock provider. -func NewSQLMockProvider() (*SQLMockProvider, sqlmock.Sqlmock) { - provider := SQLMockProvider{ - SQLProvider{ - name: "sqlmock", - - sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements, - sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements, - - sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName), - sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName), - - sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName), - sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName), - sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName), - - sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName), - sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName), - sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName), - - sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName), - sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName), - - sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName), - sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName), - - sqlGetExistingTables: "SELECT name FROM sqlite_master WHERE type='table'", - - sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName), - sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName), - }, - } - - db, mock, err := sqlmock.New() - - if err != nil { - provider.log.Fatalf("Unable to create SQL database: %s", err) - } - - provider.db = db - - /* - We do initialize in the tests rather than in the new up. - */ - - return &provider, mock -} diff --git a/internal/storage/types.go b/internal/storage/types.go index a148f5004..89a37ac00 100644 --- a/internal/storage/types.go +++ b/internal/storage/types.go @@ -1,18 +1,28 @@ package storage -import ( - "database/sql" - "strconv" -) - -// SchemaVersion is a simple int representation of the schema version. -type SchemaVersion int - -// ToString converts the schema version into a string and returns that converted value. -func (s SchemaVersion) ToString() string { - return strconv.Itoa(int(s)) +// SchemaMigration represents an intended migration. +type SchemaMigration struct { + Version int + Name string + Provider string + Up bool + Query string } -type transaction interface { - Exec(query string, args ...interface{}) (sql.Result, error) +// Before returns the version the schema should be at Before the migration is applied. +func (m SchemaMigration) Before() (before int) { + if m.Up { + return m.Version - 1 + } + + return m.Version +} + +// After returns the version the schema will be at After the migration is applied. +func (m SchemaMigration) After() (after int) { + if m.Up { + return m.Version + } + + return m.Version - 1 } diff --git a/internal/storage/upgrades.go b/internal/storage/upgrades.go deleted file mode 100644 index b42ece73c..000000000 --- a/internal/storage/upgrades.go +++ /dev/null @@ -1,76 +0,0 @@ -package storage - -import ( - "fmt" - "sort" - - "github.com/authelia/authelia/v4/internal/utils" -) - -func (p *SQLProvider) upgradeCreateTableStatements(tx transaction, statements map[string]string, existingTables []string) error { - keys := make([]string, 0, len(statements)) - for k := range statements { - keys = append(keys, k) - } - - sort.Strings(keys) - - for _, table := range keys { - if !utils.IsStringInSlice(table, existingTables) { - _, err := tx.Exec(fmt.Sprintf(statements[table], table)) - if err != nil { - return fmt.Errorf("unable to create table %s: %v", table, err) - } - } - } - - return nil -} - -func (p *SQLProvider) upgradeRunMultipleStatements(tx transaction, statements []string) error { - for _, statement := range statements { - _, err := tx.Exec(statement) - if err != nil { - return err - } - } - - return nil -} - -// upgradeFinalize sets the schema version and logs a message, as well as any other future finalization tasks. -func (p *SQLProvider) upgradeFinalize(tx transaction, version SchemaVersion) error { - _, err := tx.Exec(p.sqlConfigSetValue, "schema", "version", version.ToString()) - if err != nil { - return err - } - - p.log.Debugf("%s%d", storageSchemaUpgradeMessage, version) - - return nil -} - -// upgradeSchemaToVersion001 upgrades the schema to version 1. -func (p *SQLProvider) upgradeSchemaToVersion001(tx transaction, tables []string) error { - version := SchemaVersion(1) - - err := p.upgradeCreateTableStatements(tx, p.sqlUpgradesCreateTableStatements[version], tables) - if err != nil { - return err - } - - // Skip mysql create index statements. It doesn't support CREATE INDEX IF NOT EXIST. May be able to work around this with an Index struct. - if p.name != "mysql" { - err = p.upgradeRunMultipleStatements(tx, p.sqlUpgradesCreateTableIndexesStatements[1]) - if err != nil { - return fmt.Errorf("unable to create index: %v", err) - } - } - - err = p.upgradeFinalize(tx, version) - if err != nil { - return err - } - - return nil -} diff --git a/internal/suites/suite_standalone_test.go b/internal/suites/suite_standalone_test.go index 6f2ce907c..f02402027 100644 --- a/internal/suites/suite_standalone_test.go +++ b/internal/suites/suite_standalone_test.go @@ -122,7 +122,9 @@ func (s *StandaloneWebDriverSuite) TestShouldCheckUserIsAskedToRegisterDevice() // Clean up any TOTP secret already in DB. provider := storage.NewSQLiteProvider("/tmp/db.sqlite3") - require.NoError(s.T(), provider.DeleteTOTPSecret(username)) + + require.NoError(s.T(), provider.StartupCheck()) + require.NoError(s.T(), provider.DeleteTOTPConfiguration(ctx, username)) // Login one factor. s.doLoginOneFactor(s.T(), s.Context(ctx), username, password, false, "")