authelia/internal/storage/sql_provider_test.go

401 lines
13 KiB
Go

package storage
import (
"database/sql/driver"
"encoding/base64"
"fmt"
"sort"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/models"
)
// currentSchemaMockSchemaVersion is the schema version value that the mocked
// config-table lookup ("SELECT value FROM ... WHERE category=? AND key_name=?")
// returns in the provider method tests.
const (
	currentSchemaMockSchemaVersion = "1"
)
// TestSQLInitializeDatabase checks that initialize, when sqlite_master reports
// no existing tables, executes a CREATE TABLE for every entry of
// sqlUpgradeCreateTableStatements[1], creates the usr_time_idx index on the
// authentication logs table and writes the schema version to the config table,
// all between a BEGIN and a COMMIT.
func TestSQLInitializeDatabase(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}))

	mock.ExpectBegin()

	// Map iteration order is random, so register the CREATE TABLE expectations
	// in sorted table-name order to keep them deterministic.
	keys := make([]string, 0, len(sqlUpgradeCreateTableStatements[1]))
	for k := range sqlUpgradeCreateTableStatements[1] {
		keys = append(keys, k)
	}

	sort.Strings(keys)

	for _, table := range keys {
		mock.ExpectExec(
			fmt.Sprintf("CREATE TABLE %s .*", table)).
			WillReturnResult(sqlmock.NewResult(0, 0))
	}

	mock.ExpectExec(
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
		WillReturnResult(sqlmock.NewResult(0, 0))

	mock.ExpectExec(
		fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
		WithArgs("schema", "version", currentSchemaMockSchemaVersion).
		WillReturnResult(sqlmock.NewResult(1, 1))

	mock.ExpectCommit()

	assert.NoError(t, provider.initialize(provider.db))

	// Fail if any registered statement (e.g. the COMMIT) was never executed.
	assert.NoError(t, mock.ExpectationsWereMet())
}
// TestSQLUpgradeDatabase checks that initialize, when sqlite_master lists every
// table except the config table, expects only the config table to be created,
// then creates the usr_time_idx index on the authentication logs table and
// writes the schema version, all between a BEGIN and a COMMIT.
func TestSQLUpgradeDatabase(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	// Every table is present except configTableName.
	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}).
			AddRow(userPreferencesTableName).
			AddRow(identityVerificationTokensTableName).
			AddRow(totpSecretsTableName).
			AddRow(u2fDeviceHandlesTableName).
			AddRow(authenticationLogsTableName))

	mock.ExpectBegin()

	mock.ExpectExec(
		fmt.Sprintf("CREATE TABLE %s .*", configTableName)).
		WillReturnResult(sqlmock.NewResult(0, 0))

	mock.ExpectExec(
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
		WillReturnResult(sqlmock.NewResult(0, 0))

	mock.ExpectExec(
		fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
		WithArgs("schema", "version", currentSchemaMockSchemaVersion).
		WillReturnResult(sqlmock.NewResult(1, 1))

	mock.ExpectCommit()

	assert.NoError(t, provider.initialize(provider.db))

	// Fail if any registered statement (e.g. the COMMIT) was never executed.
	assert.NoError(t, mock.ExpectationsWereMet())
}
// TestSQLProviderMethodsAuthenticationLogs exercises AppendAuthenticationLog and
// LoadLatestAuthenticationLogs against a mocked database whose config table
// reports the current schema version, including a lookup that returns no rows.
func TestSQLProviderMethodsAuthenticationLogs(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}).
			AddRow(userPreferencesTableName).
			AddRow(identityVerificationTokensTableName).
			AddRow(totpSecretsTableName).
			AddRow(u2fDeviceHandlesTableName).
			AddRow(authenticationLogsTableName).
			AddRow(configTableName))

	args := []driver.Value{"schema", "version"}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"value"}).
			AddRow(currentSchemaMockSchemaVersion))

	// Every later expectation assumes initialization succeeded, so stop here if it did not.
	require.NoError(t, provider.initialize(provider.db))

	attempts := []models.AuthenticationAttempt{
		{Username: unitTestUser, Successful: true, Time: time.Unix(1577880001, 0)},
		{Username: unitTestUser, Successful: true, Time: time.Unix(1577880002, 0)},
		{Username: unitTestUser, Successful: false, Time: time.Unix(1577880003, 0)},
	}

	for id, attempt := range attempts {
		args = []driver.Value{attempt.Username, attempt.Successful, attempt.Time.Unix()}
		mock.ExpectExec(
			fmt.Sprintf("INSERT INTO %s \\(username, successful, time\\) VALUES \\(\\?, \\?, \\?\\)", authenticationLogsTableName)).
			WithArgs(args...).
			WillReturnResult(sqlmock.NewResult(int64(id), 1))

		assert.NoError(t, provider.AppendAuthenticationLog(attempt))
	}

	// The query orders by time DESC, so the mocked result set lists the newest
	// attempt first, matching what a real database would return.
	rows := sqlmock.NewRows([]string{"successful", "time"})
	for i := len(attempts) - 1; i >= 0; i-- {
		rows.AddRow(attempts[i].Successful, attempts[i].Time.Unix())
	}

	after := time.Unix(1577880000, 0)
	args = []driver.Value{after.Unix(), unitTestUser}

	mock.ExpectQuery(
		fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
		WithArgs(args...).
		WillReturnRows(rows)

	results, err := provider.LoadLatestAuthenticationLogs(unitTestUser, after)
	assert.NoError(t, err)
	require.Len(t, results, len(attempts))

	// results[i] must correspond to the i-th newest attempt; the username is not
	// part of the SELECT, so it must come from the method's argument.
	for i, result := range results {
		expected := attempts[len(attempts)-1-i]

		assert.Equal(t, unitTestUser, result.Username)
		assert.Equal(t, expected.Successful, result.Successful)
		assert.Equal(t, expected.Time, result.Time)
	}

	// Test Blank Rows.
	mock.ExpectQuery(
		fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"successful", "time"}))

	results, err = provider.LoadLatestAuthenticationLogs(unitTestUser, after)
	assert.NoError(t, err)
	assert.Len(t, results, 0)

	assert.NoError(t, mock.ExpectationsWereMet())
}
// TestSQLProviderMethodsPreferred exercises SavePreferred2FAMethod and
// LoadPreferred2FAMethod against a mocked database whose config table reports
// currentSchemaMockSchemaVersion, including a lookup that returns no rows.
func TestSQLProviderMethodsPreferred(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}).
			AddRow(userPreferencesTableName).
			AddRow(identityVerificationTokensTableName).
			AddRow(totpSecretsTableName).
			AddRow(u2fDeviceHandlesTableName).
			AddRow(authenticationLogsTableName).
			AddRow(configTableName))

	args := []driver.Value{"schema", "version"}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"value"}).
			AddRow(currentSchemaMockSchemaVersion))

	// Every later expectation assumes initialization succeeded, so stop here if it did not.
	require.NoError(t, provider.initialize(provider.db))

	mock.ExpectExec(
		fmt.Sprintf("REPLACE INTO %s \\(username, second_factor_method\\) VALUES \\(\\?, \\?\\)", userPreferencesTableName)).
		WithArgs(unitTestUser, authentication.TOTP).
		WillReturnResult(sqlmock.NewResult(0, 1))

	assert.NoError(t, provider.SavePreferred2FAMethod(unitTestUser, authentication.TOTP))

	mock.ExpectQuery(
		fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
		WithArgs(unitTestUser).
		WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}).AddRow(authentication.TOTP))

	method, err := provider.LoadPreferred2FAMethod(unitTestUser)
	assert.NoError(t, err)
	assert.Equal(t, authentication.TOTP, method)

	// Test Blank Rows: an empty result yields an empty method and no error.
	mock.ExpectQuery(
		fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
		WithArgs(unitTestUser).
		WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}))

	method, err = provider.LoadPreferred2FAMethod(unitTestUser)
	assert.NoError(t, err)
	assert.Equal(t, "", method)

	assert.NoError(t, mock.ExpectationsWereMet())
}
// TestSQLProviderMethodsTOTP exercises SaveTOTPSecret, LoadTOTPSecret and
// DeleteTOTPSecret against a mocked database whose config table reports
// currentSchemaMockSchemaVersion, and checks the error returned when no secret
// row is found.
func TestSQLProviderMethodsTOTP(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}).
			AddRow(userPreferencesTableName).
			AddRow(identityVerificationTokensTableName).
			AddRow(totpSecretsTableName).
			AddRow(u2fDeviceHandlesTableName).
			AddRow(authenticationLogsTableName).
			AddRow(configTableName))

	args := []driver.Value{"schema", "version"}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"value"}).
			AddRow(currentSchemaMockSchemaVersion))

	// Every later expectation assumes initialization succeeded, so stop here if it did not.
	require.NoError(t, provider.initialize(provider.db))

	pretendSecret := "abc123"

	args = []driver.Value{unitTestUser, pretendSecret}
	mock.ExpectExec(
		fmt.Sprintf("REPLACE INTO %s \\(username, secret\\) VALUES \\(\\?, \\?\\)", totpSecretsTableName)).
		WithArgs(args...).
		WillReturnResult(sqlmock.NewResult(0, 1))

	assert.NoError(t, provider.SaveTOTPSecret(unitTestUser, pretendSecret))

	args = []driver.Value{unitTestUser}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"secret"}).AddRow(pretendSecret))

	secret, err := provider.LoadTOTPSecret(unitTestUser)
	assert.NoError(t, err)
	assert.Equal(t, pretendSecret, secret)

	mock.ExpectExec(
		fmt.Sprintf("DELETE FROM %s WHERE username=\\?", totpSecretsTableName)).
		WithArgs(unitTestUser).
		WillReturnResult(sqlmock.NewResult(0, 1))

	assert.NoError(t, provider.DeleteTOTPSecret(unitTestUser))

	// Test Blank Rows: a missing secret is reported as an error with an empty value.
	mock.ExpectQuery(
		fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"secret"}))

	secret, err = provider.LoadTOTPSecret(unitTestUser)
	assert.EqualError(t, err, "no TOTP secret registered")
	assert.Equal(t, "", secret)

	assert.NoError(t, mock.ExpectationsWereMet())
}
// TestSQLProviderMethodsU2F exercises SaveU2FDeviceHandle and
// LoadU2FDeviceHandle against a mocked database whose config table reports
// currentSchemaMockSchemaVersion. The mocked statements expect the key handle
// and public key as standard base64 strings, and the test checks that they
// load back as the original bytes. It also covers the error returned when no
// row is found.
func TestSQLProviderMethodsU2F(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}).
			AddRow(userPreferencesTableName).
			AddRow(identityVerificationTokensTableName).
			AddRow(totpSecretsTableName).
			AddRow(u2fDeviceHandlesTableName).
			AddRow(authenticationLogsTableName).
			AddRow(configTableName))

	args := []driver.Value{"schema", "version"}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"value"}).
			AddRow(currentSchemaMockSchemaVersion))

	// Every later expectation assumes initialization succeeded, so stop here if it did not.
	require.NoError(t, provider.initialize(provider.db))

	pretendKeyHandle := []byte("abc")
	pretendPublicKey := []byte("123")
	pretendKeyHandleB64 := base64.StdEncoding.EncodeToString(pretendKeyHandle)
	pretendPublicKeyB64 := base64.StdEncoding.EncodeToString(pretendPublicKey)

	args = []driver.Value{unitTestUser, pretendKeyHandleB64, pretendPublicKeyB64}
	mock.ExpectExec(
		fmt.Sprintf("REPLACE INTO %s \\(username, keyHandle, publicKey\\) VALUES \\(\\?, \\?, \\?\\)", u2fDeviceHandlesTableName)).
		WithArgs(args...).
		WillReturnResult(sqlmock.NewResult(0, 1))

	assert.NoError(t, provider.SaveU2FDeviceHandle(unitTestUser, pretendKeyHandle, pretendPublicKey))

	args = []driver.Value{unitTestUser}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}).
			AddRow(pretendKeyHandleB64, pretendPublicKeyB64))

	keyHandle, publicKey, err := provider.LoadU2FDeviceHandle(unitTestUser)
	assert.NoError(t, err)
	assert.Equal(t, pretendKeyHandle, keyHandle)
	assert.Equal(t, pretendPublicKey, publicKey)

	// Test Blank Rows: a missing handle is reported as an error with nil byte slices.
	mock.ExpectQuery(
		fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}))

	keyHandle, publicKey, err = provider.LoadU2FDeviceHandle(unitTestUser)
	assert.EqualError(t, err, "no U2F device handle found")
	assert.Nil(t, keyHandle)
	assert.Nil(t, publicKey)

	assert.NoError(t, mock.ExpectationsWereMet())
}
// TestSQLProviderMethodsIdentityVerificationTokens exercises
// SaveIdentityVerificationToken, FindIdentityVerificationToken and
// RemoveIdentityVerificationToken against a mocked database whose config table
// reports currentSchemaMockSchemaVersion. The mocked EXISTS query returns true
// before removal and false after.
func TestSQLProviderMethodsIdentityVerificationTokens(t *testing.T) {
	provider, mock := NewSQLMockProvider()

	mock.ExpectQuery(
		"SELECT name FROM sqlite_master WHERE type='table'").
		WillReturnRows(sqlmock.NewRows([]string{"name"}).
			AddRow(userPreferencesTableName).
			AddRow(identityVerificationTokensTableName).
			AddRow(totpSecretsTableName).
			AddRow(u2fDeviceHandlesTableName).
			AddRow(authenticationLogsTableName).
			AddRow(configTableName))

	args := []driver.Value{"schema", "version"}
	mock.ExpectQuery(
		fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
		WithArgs(args...).
		WillReturnRows(sqlmock.NewRows([]string{"value"}).
			AddRow(currentSchemaMockSchemaVersion))

	// Every later expectation assumes initialization succeeded, so stop here if it did not.
	require.NoError(t, provider.initialize(provider.db))

	fakeIdentityVerificationToken := "abc"

	mock.ExpectExec(
		fmt.Sprintf("INSERT INTO %s \\(token\\) VALUES \\(\\?\\)", identityVerificationTokensTableName)).
		WithArgs(fakeIdentityVerificationToken).
		WillReturnResult(sqlmock.NewResult(1, 1))

	assert.NoError(t, provider.SaveIdentityVerificationToken(fakeIdentityVerificationToken))

	mock.ExpectQuery(
		fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
		WithArgs(fakeIdentityVerificationToken).
		WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
			AddRow(true))

	valid, err := provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
	assert.NoError(t, err)
	assert.True(t, valid)

	mock.ExpectExec(
		fmt.Sprintf("DELETE FROM %s WHERE token=\\?", identityVerificationTokensTableName)).
		WithArgs(fakeIdentityVerificationToken).
		WillReturnResult(sqlmock.NewResult(0, 1))

	assert.NoError(t, provider.RemoveIdentityVerificationToken(fakeIdentityVerificationToken))

	mock.ExpectQuery(
		fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
		WithArgs(fakeIdentityVerificationToken).
		WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
			AddRow(false))

	valid, err = provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
	assert.NoError(t, err)
	assert.False(t, valid)

	assert.NoError(t, mock.ExpectationsWereMet())
}