297 lines
6.9 KiB
Go
297 lines
6.9 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"github.com/authelia/authelia/v4/internal/models"
|
|
"github.com/authelia/authelia/v4/internal/utils"
|
|
)
|
|
|
|
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
|
|
// by this command to encrypt the values again and update them using a transaction.
|
|
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) {
|
|
tx, err := p.db.Beginx()
|
|
if err != nil {
|
|
return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
|
|
}
|
|
|
|
key := sha256.Sum256([]byte(encryptionKey))
|
|
|
|
if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = p.schemaEncryptionChangeKeyU2F(ctx, tx, key); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var configs []models.TOTPConfiguration
|
|
|
|
for page := 0; true; page++ {
|
|
if configs, err = p.LoadTOTPConfigurations(ctx, 10, page); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
|
|
for _, config := range configs {
|
|
if config.Secret, err = utils.Encrypt(config.Secret, &key); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
|
|
if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
}
|
|
|
|
if len(configs) != 10 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaEncryptionChangeKeyU2F(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
|
|
var devices []models.U2FDevice
|
|
|
|
for page := 0; true; page++ {
|
|
if devices, err = p.LoadU2FDevices(ctx, 10, page); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
|
|
for _, device := range devices {
|
|
if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
|
|
if err = p.updateU2FDevicePublicKey(ctx, device); err != nil {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
|
}
|
|
|
|
return fmt.Errorf("rollback due to error: %w", err)
|
|
}
|
|
}
|
|
|
|
if len(devices) != 10 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
|
|
func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) {
|
|
version, err := p.SchemaVersion(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if version < 1 {
|
|
return ErrSchemaEncryptionVersionUnsupported
|
|
}
|
|
|
|
var errs []error
|
|
|
|
if _, err = p.getEncryptionValue(ctx, encryptionNameCheck); err != nil {
|
|
errs = append(errs, ErrSchemaEncryptionInvalidKey)
|
|
}
|
|
|
|
if verbose {
|
|
if err = p.schemaEncryptionCheckTOTP(ctx); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
|
|
if err = p.schemaEncryptionCheckU2F(ctx); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
if len(errs) != 0 {
|
|
for i, e := range errs {
|
|
if i == 0 {
|
|
err = e
|
|
|
|
continue
|
|
}
|
|
|
|
err = fmt.Errorf("%w, %v", err, e)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
|
|
var (
|
|
config models.TOTPConfiguration
|
|
row int
|
|
invalid int
|
|
total int
|
|
)
|
|
|
|
pageSize := 10
|
|
|
|
var rows *sqlx.Rows
|
|
|
|
for page := 0; true; page++ {
|
|
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
|
}
|
|
|
|
row = 0
|
|
|
|
for rows.Next() {
|
|
total++
|
|
row++
|
|
|
|
if err = rows.StructScan(&config); err != nil {
|
|
_ = rows.Close()
|
|
return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
|
|
}
|
|
|
|
if _, err = p.decrypt(config.Secret); err != nil {
|
|
invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
if row < pageSize {
|
|
break
|
|
}
|
|
}
|
|
|
|
if invalid != 0 {
|
|
return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
|
|
var (
|
|
device models.U2FDevice
|
|
row int
|
|
invalid int
|
|
total int
|
|
)
|
|
|
|
pageSize := 10
|
|
|
|
var rows *sqlx.Rows
|
|
|
|
for page := 0; true; page++ {
|
|
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, pageSize, pageSize*page); err != nil {
|
|
_ = rows.Close()
|
|
|
|
return fmt.Errorf("error selecting U2F devices: %w", err)
|
|
}
|
|
|
|
row = 0
|
|
|
|
for rows.Next() {
|
|
total++
|
|
row++
|
|
|
|
if err = rows.StructScan(&device); err != nil {
|
|
_ = rows.Close()
|
|
return fmt.Errorf("error scanning U2F device to struct: %w", err)
|
|
}
|
|
|
|
if _, err = p.decrypt(device.PublicKey); err != nil {
|
|
invalid++
|
|
}
|
|
}
|
|
|
|
_ = rows.Close()
|
|
|
|
if row < pageSize {
|
|
break
|
|
}
|
|
}
|
|
|
|
if invalid != 0 {
|
|
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
|
return utils.Encrypt(clearText, &p.key)
|
|
}
|
|
|
|
func (p SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
|
return utils.Decrypt(cipherText, &p.key)
|
|
}
|
|
|
|
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
|
|
var encryptedValue []byte
|
|
|
|
err = p.db.GetContext(ctx, &encryptedValue, p.sqlSelectEncryptionValue, name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return p.decrypt(encryptedValue)
|
|
}
|
|
|
|
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]byte, e sqlx.ExecerContext) (err error) {
|
|
valueClearText := uuid.New()
|
|
|
|
value, err := utils.Encrypt([]byte(valueClearText.String()), key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if e != nil {
|
|
_, err = e.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
|
|
} else {
|
|
_, err = p.db.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, value)
|
|
}
|
|
|
|
return err
|
|
}
|