package storage import ( "context" "crypto/sha256" "database/sql" "errors" "fmt" "time" "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/model" ) // NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's. func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceName string) (provider SQLProvider) { db, err := sqlx.Open(driverName, dataSourceName) provider = SQLProvider{ db: db, key: sha256.Sum256([]byte(config.Storage.EncryptionKey)), name: name, driverName: driverName, config: config, errOpen: err, log: logging.Logger(), sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs), sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs), sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification), sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification), sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification), sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations), sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations), sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations), sqlSelectTOTPConfigs: fmt.Sprintf(queryFmtSelectTOTPConfigurations, tableTOTPConfigurations), sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations), sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations), sqlUpdateTOTPConfigRecordSignIn: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignIn, tableTOTPConfigurations), sqlUpdateTOTPConfigRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignInByUsername, tableTOTPConfigurations), sqlUpsertWebauthnDevice: fmt.Sprintf(queryFmtUpsertWebauthnDevice, tableWebauthnDevices), sqlSelectWebauthnDevices: fmt.Sprintf(queryFmtSelectWebauthnDevices, tableWebauthnDevices), sqlSelectWebauthnDevicesByUsername: fmt.Sprintf(queryFmtSelectWebauthnDevicesByUsername, tableWebauthnDevices), sqlUpdateWebauthnDevicePublicKey: fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices), sqlUpdateWebauthnDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateWebauthnDevicePublicKeyByUsername, tableWebauthnDevices), sqlUpdateWebauthnDeviceRecordSignIn: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignIn, tableWebauthnDevices), sqlUpdateWebauthnDeviceRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignInByUsername, tableWebauthnDevices), sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices), sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices), sqlSelectDuoDevice: fmt.Sprintf(queryFmtSelectDuoDevice, tableDuoDevices), sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences), sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebauthnDevices, tableDuoDevices, tableUserPreferences), sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier), sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier), sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession), sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession), sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession), sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession), sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession), sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession), sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession), sqlRevokeOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AccessTokenSession), sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession), sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession), sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession), sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession), sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession), sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession), sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession), sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession), sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession), sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession), sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession), sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession), sqlRevokeOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2OpenIDConnectSession), sqlRevokeOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession), sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession), sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession), sqlInsertOAuth2ConsentSession: fmt.Sprintf(queryFmtInsertOAuth2ConsentSession, tableOAuth2ConsentSession), sqlUpdateOAuth2ConsentSessionResponse: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionResponse, tableOAuth2ConsentSession), sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession), sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession), sqlSelectOAuth2ConsentSessionsPreConfigured: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionsPreConfigured, tableOAuth2ConsentSession), sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations), sqlUpsertEncryptionValue: fmt.Sprintf(queryFmtUpsertEncryptionValue, tableEncryption), sqlSelectEncryptionValue: fmt.Sprintf(queryFmtSelectEncryptionValue, tableEncryption), sqlFmtRenameTable: queryFmtRenameTable, } return provider } // SQLProvider is a storage provider persisting data in a SQL database. type SQLProvider struct { db *sqlx.DB key [32]byte name string driverName string schema string config *schema.Configuration errOpen error log *logrus.Logger // Table: authentication_logs. sqlInsertAuthenticationAttempt string sqlSelectAuthenticationAttemptsByUsername string // Table: identity_verification. sqlInsertIdentityVerification string sqlConsumeIdentityVerification string sqlSelectIdentityVerification string // Table: totp_configurations. sqlUpsertTOTPConfig string sqlDeleteTOTPConfig string sqlSelectTOTPConfig string sqlSelectTOTPConfigs string sqlUpdateTOTPConfigSecret string sqlUpdateTOTPConfigSecretByUsername string sqlUpdateTOTPConfigRecordSignIn string sqlUpdateTOTPConfigRecordSignInByUsername string // Table: webauthn_devices. sqlUpsertWebauthnDevice string sqlSelectWebauthnDevices string sqlSelectWebauthnDevicesByUsername string sqlUpdateWebauthnDevicePublicKey string sqlUpdateWebauthnDevicePublicKeyByUsername string sqlUpdateWebauthnDeviceRecordSignIn string sqlUpdateWebauthnDeviceRecordSignInByUsername string // Table: duo_devices. sqlUpsertDuoDevice string sqlDeleteDuoDevice string sqlSelectDuoDevice string // Table: user_preferences. sqlUpsertPreferred2FAMethod string sqlSelectPreferred2FAMethod string sqlSelectUserInfo string // Table: user_opaque_identifier. sqlInsertUserOpaqueIdentifier string sqlSelectUserOpaqueIdentifier string sqlSelectUserOpaqueIdentifiers string sqlSelectUserOpaqueIdentifierBySignature string // Table: migrations. sqlInsertMigration string sqlSelectMigrations string sqlSelectLatestMigration string // Table: encryption. sqlUpsertEncryptionValue string sqlSelectEncryptionValue string // Table: oauth2_authorization_code_session. sqlInsertOAuth2AuthorizeCodeSession string sqlSelectOAuth2AuthorizeCodeSession string sqlRevokeOAuth2AuthorizeCodeSession string sqlRevokeOAuth2AuthorizeCodeSessionByRequestID string sqlDeactivateOAuth2AuthorizeCodeSession string sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID string // Table: oauth2_access_token_session. sqlInsertOAuth2AccessTokenSession string sqlSelectOAuth2AccessTokenSession string sqlRevokeOAuth2AccessTokenSession string sqlRevokeOAuth2AccessTokenSessionByRequestID string sqlDeactivateOAuth2AccessTokenSession string sqlDeactivateOAuth2AccessTokenSessionByRequestID string // Table: oauth2_refresh_token_session. sqlInsertOAuth2RefreshTokenSession string sqlSelectOAuth2RefreshTokenSession string sqlRevokeOAuth2RefreshTokenSession string sqlRevokeOAuth2RefreshTokenSessionByRequestID string sqlDeactivateOAuth2RefreshTokenSession string sqlDeactivateOAuth2RefreshTokenSessionByRequestID string // Table: oauth2_pkce_request_session. sqlInsertOAuth2PKCERequestSession string sqlSelectOAuth2PKCERequestSession string sqlRevokeOAuth2PKCERequestSession string sqlRevokeOAuth2PKCERequestSessionByRequestID string sqlDeactivateOAuth2PKCERequestSession string sqlDeactivateOAuth2PKCERequestSessionByRequestID string // Table: oauth2_openid_connect_session. sqlInsertOAuth2OpenIDConnectSession string sqlSelectOAuth2OpenIDConnectSession string sqlRevokeOAuth2OpenIDConnectSession string sqlRevokeOAuth2OpenIDConnectSessionByRequestID string sqlDeactivateOAuth2OpenIDConnectSession string sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string // Table: oauth2_consent_session. sqlInsertOAuth2ConsentSession string sqlUpdateOAuth2ConsentSessionResponse string sqlUpdateOAuth2ConsentSessionGranted string sqlSelectOAuth2ConsentSessionByChallengeID string sqlSelectOAuth2ConsentSessionsPreConfigured string sqlUpsertOAuth2BlacklistedJTI string sqlSelectOAuth2BlacklistedJTI string // Utility. sqlSelectExistingTables string sqlFmtRenameTable string } // Close the underlying database connection. func (p *SQLProvider) Close() (err error) { return p.db.Close() } // StartupCheck implements the provider startup check interface. func (p *SQLProvider) StartupCheck() (err error) { if p.errOpen != nil { return fmt.Errorf("error opening database: %w", p.errOpen) } // TODO: Decide if this is needed, or if it should be configurable. for i := 0; i < 19; i++ { if err = p.db.Ping(); err == nil { break } time.Sleep(time.Millisecond * 500) } if err != nil { return fmt.Errorf("error pinging database: %w", err) } p.log.Infof("Storage schema is being checked for updates") ctx := context.Background() if err = p.SchemaEncryptionCheckKey(ctx, false); err != nil && !errors.Is(err, ErrSchemaEncryptionVersionUnsupported) { return err } err = p.SchemaMigrate(ctx, true, SchemaLatest) switch err { case ErrSchemaAlreadyUpToDate: p.log.Infof("Storage schema is already up to date") return nil case nil: return nil default: return fmt.Errorf("error during schema migrate: %w", err) } } // BeginTX begins a transaction. func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) { var tx *sql.Tx if tx, err = p.db.Begin(); err != nil { return nil, err } return context.WithValue(ctx, ctxKeyTransaction, tx), nil } // Commit performs a database commit. func (p *SQLProvider) Commit(ctx context.Context) (err error) { tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx) if !ok { return errors.New("could not retrieve tx") } return tx.Commit() } // Rollback performs a database rollback. func (p *SQLProvider) Rollback(ctx context.Context) (err error) { tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx) if !ok { return errors.New("could not retrieve tx") } return tx.Rollback() } // SaveUserOpaqueIdentifier saves a new opaque user identifier to the database. func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil { return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err) } return nil } // LoadUserOpaqueIdentifier selects an opaque user identifier from the database. func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) { opaqueID = &model.UserOpaqueIdentifier{} if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, nil default: return nil, err } } return opaqueID, nil } // LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database. func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) { var rows *sqlx.Rows if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil { return nil, fmt.Errorf("error selecting user opaque identifiers: %w", err) } var opaqueID *model.UserOpaqueIdentifier for rows.Next() { opaqueID = &model.UserOpaqueIdentifier{} if err = rows.StructScan(opaqueID); err != nil { return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err) } opaqueIDs = append(opaqueIDs, *opaqueID) } return opaqueIDs, nil } // LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username. func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { opaqueID = &model.UserOpaqueIdentifier{} if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil { switch { case errors.Is(err, sql.ErrNoRows): return nil, nil default: return nil, err } } return opaqueID, nil } // SaveOAuth2ConsentSession inserts an OAuth2.0 consent. func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession, consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted, consent.RequestedAt, consent.RespondedAt, consent.ExpiresAt, consent.Form, consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience); err != nil { return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.String(), err) } return nil } // SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent with the consent response. func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) { _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.ExpiresAt, consent.GrantedScopes, consent.GrantedAudience, consent.ID) if err != nil { return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject, err) } return nil } // SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint. func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil { return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err) } return nil } // LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID. func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) { consent = &model.OAuth2ConsentSession{} if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil { return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err) } return consent, nil } // LoadOAuth2ConsentSessionsPreConfigured returns an OAuth2.0 consents that are pre-configured given the consent signature. func (p *SQLProvider) LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error) { var r *sqlx.Rows if r, err = p.db.QueryxContext(ctx, p.sqlSelectOAuth2ConsentSessionsPreConfigured, clientID, subject); err != nil { if errors.Is(err, sql.ErrNoRows) { return &ConsentSessionRows{}, nil } return &ConsentSessionRows{}, fmt.Errorf("error selecting oauth2 consent session by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err) } return &ConsentSessionRows{rows: r}, nil } // SaveOAuth2Session saves a OAuth2Session to the database. func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) { var query string switch sessionType { case OAuth2SessionTypeAuthorizeCode: query = p.sqlInsertOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlInsertOAuth2AccessTokenSession case OAuth2SessionTypeRefreshToken: query = p.sqlInsertOAuth2RefreshTokenSession case OAuth2SessionTypePKCEChallenge: query = p.sqlInsertOAuth2PKCERequestSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlInsertOAuth2OpenIDConnectSession default: return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType) } if session.Session, err = p.encrypt(session.Session); err != nil { return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) } _, err = p.db.ExecContext(ctx, query, session.ChallengeID, session.RequestID, session.ClientID, session.Signature, session.Subject, session.RequestedAt, session.RequestedScopes, session.GrantedScopes, session.RequestedAudience, session.GrantedAudience, session.Active, session.Revoked, session.Form, session.Session) if err != nil { return fmt.Errorf("error inserting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) } return nil } // RevokeOAuth2Session marks a OAuth2Session as revoked in the database. func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) { var query string switch sessionType { case OAuth2SessionTypeAuthorizeCode: query = p.sqlRevokeOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlRevokeOAuth2AccessTokenSession case OAuth2SessionTypeRefreshToken: query = p.sqlRevokeOAuth2RefreshTokenSession case OAuth2SessionTypePKCEChallenge: query = p.sqlRevokeOAuth2PKCERequestSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlRevokeOAuth2OpenIDConnectSession default: return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType) } if _, err = p.db.ExecContext(ctx, query, signature); err != nil { return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err) } return nil } // RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database. func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) { var query string switch sessionType { case OAuth2SessionTypeAuthorizeCode: query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID case OAuth2SessionTypeAccessToken: query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID case OAuth2SessionTypeRefreshToken: query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID case OAuth2SessionTypePKCEChallenge: query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID case OAuth2SessionTypeOpenIDConnect: query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID default: return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType) } if _, err = p.db.ExecContext(ctx, query, requestID); err != nil { return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err) } return nil } // DeactivateOAuth2Session marks a OAuth2Session as inactive in the database. func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) { var query string switch sessionType { case OAuth2SessionTypeAuthorizeCode: query = p.sqlDeactivateOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlDeactivateOAuth2AccessTokenSession case OAuth2SessionTypeRefreshToken: query = p.sqlDeactivateOAuth2RefreshTokenSession case OAuth2SessionTypePKCEChallenge: query = p.sqlDeactivateOAuth2PKCERequestSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlDeactivateOAuth2OpenIDConnectSession default: return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType) } if _, err = p.db.ExecContext(ctx, query, signature); err != nil { return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err) } return nil } // DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database. func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) { var query string switch sessionType { case OAuth2SessionTypeAuthorizeCode: query = p.sqlDeactivateOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID case OAuth2SessionTypeRefreshToken: query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID case OAuth2SessionTypePKCEChallenge: query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID case OAuth2SessionTypeOpenIDConnect: query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID default: return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType) } if _, err = p.db.ExecContext(ctx, query, requestID); err != nil { return fmt.Errorf("error deactivating oauth2 %s session with request id '%s': %w", sessionType, requestID, err) } return nil } // LoadOAuth2Session saves a OAuth2Session from the database. func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) { var query string switch sessionType { case OAuth2SessionTypeAuthorizeCode: query = p.sqlSelectOAuth2AuthorizeCodeSession case OAuth2SessionTypeAccessToken: query = p.sqlSelectOAuth2AccessTokenSession case OAuth2SessionTypeRefreshToken: query = p.sqlSelectOAuth2RefreshTokenSession case OAuth2SessionTypePKCEChallenge: query = p.sqlSelectOAuth2PKCERequestSession case OAuth2SessionTypeOpenIDConnect: query = p.sqlSelectOAuth2OpenIDConnectSession default: return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType) } session = &model.OAuth2Session{} if err = p.db.GetContext(ctx, session, query, signature); err != nil { return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err) } if session.Session, err = p.decrypt(session.Session); err != nil { return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject, session.RequestID, err) } return session, nil } // SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database. func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil { return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) } return nil } // LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database. func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) { blacklistedJTI = &model.OAuth2BlacklistedJTI{} if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil { return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) } return blacklistedJTI, nil } // SavePreferred2FAMethod save the preferred method for 2FA to the database. func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil { return fmt.Errorf("error upserting preferred two factor method for user '%s': %w", username, err) } return nil } // LoadPreferred2FAMethod load the preferred method for 2FA from the database. func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) { err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username) switch { case err == nil: return method, nil case errors.Is(err, sql.ErrNoRows): return "", sql.ErrNoRows default: return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err) } } // LoadUserInfo loads the model.UserInfo from the database. func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) { err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username) switch { case err == nil, errors.Is(err, sql.ErrNoRows): return info, nil default: return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) } } // SaveIdentityVerification save an identity verification record to the database. func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt, verification.Username, verification.Action); err != nil { return fmt.Errorf("error inserting identity verification for user '%s' with uuid '%s': %w", verification.Username, verification.JTI, err) } return nil } // ConsumeIdentityVerification marks an identity verification record in the database as consumed. func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil { return fmt.Errorf("error updating identity verification: %w", err) } return nil } // FindIdentityVerification checks if an identity verification record is in the database and active. func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) { verification := model.IdentityVerification{} if err = p.db.GetContext(ctx, &verification, p.sqlSelectIdentityVerification, jti); err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil } return false, fmt.Errorf("error selecting identity verification exists: %w", err) } switch { case verification.Consumed != nil: return false, fmt.Errorf("the token has already been consumed") case verification.ExpiresAt.Before(time.Now()): return false, fmt.Errorf("the token expired %s ago", time.Since(verification.ExpiresAt)) default: return true, nil } } // SaveTOTPConfiguration save a TOTP configuration of a given user in the database. func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) { if config.Secret, err = p.encrypt(config.Secret); err != nil { return fmt.Errorf("error encrypting the TOTP configuration secret for user '%s': %w", config.Username, err) } if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, config.CreatedAt, config.LastUsedAt, config.Username, config.Issuer, config.Algorithm, config.Digits, config.Period, config.Secret); err != nil { return fmt.Errorf("error upserting TOTP configuration for user '%s': %w", config.Username, err) } return nil } // UpdateTOTPConfigurationSignIn updates a registered Webauthn devices sign in information. func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt *time.Time) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil { return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err) } return nil } // DeleteTOTPConfiguration delete a TOTP configuration from the database given a username. func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil { return fmt.Errorf("error deleting TOTP configuration for user '%s': %w", username, err) } return nil } // LoadTOTPConfiguration load a TOTP configuration given a username from the database. func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) { config = &model.TOTPConfiguration{} if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoTOTPConfiguration } return nil, fmt.Errorf("error selecting TOTP configuration for user '%s': %w", username, err) } if config.Secret, err = p.decrypt(config.Secret); err != nil { return nil, fmt.Errorf("error decrypting the TOTP secret for user '%s': %w", username, err) } return config, nil } // LoadTOTPConfigurations load a set of TOTP configurations. func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []model.TOTPConfiguration, err error) { configs = make([]model.TOTPConfiguration, 0, limit) if err = p.db.SelectContext(ctx, &configs, p.sqlSelectTOTPConfigs, limit, limit*page); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("error selecting TOTP configurations: %w", err) } for i, c := range configs { if configs[i].Secret, err = p.decrypt(c.Secret); err != nil { return nil, fmt.Errorf("error decrypting TOTP configuration for user '%s': %w", c.Username, err) } } return configs, nil } func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config model.TOTPConfiguration) (err error) { switch config.ID { case 0: _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username) default: _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecret, config.Secret, config.ID) } if err != nil { return fmt.Errorf("error updating TOTP configuration secret for user '%s': %w", config.Username, err) } return nil } // SaveWebauthnDevice saves a registered Webauthn device. func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) { if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil { return fmt.Errorf("error encrypting the Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err) } if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice, device.CreatedAt, device.LastUsedAt, device.RPID, device.Username, device.Description, device.KID, device.PublicKey, device.AttestationType, device.Transport, device.AAGUID, device.SignCount, device.CloneWarning, ); err != nil { return fmt.Errorf("error upserting Webauthn device for user '%s' kid '%x': %w", device.Username, device.KID, err) } return nil } // UpdateWebauthnDeviceSignIn updates a registered Webauthn devices sign in information. func (p *SQLProvider) UpdateWebauthnDeviceSignIn(ctx context.Context, id int, rpid string, lastUsedAt *time.Time, signCount uint32, cloneWarning bool) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDeviceRecordSignIn, rpid, lastUsedAt, signCount, cloneWarning, id); err != nil { return fmt.Errorf("error updating Webauthn signin metadata for id '%x': %w", id, err) } return nil } // LoadWebauthnDevices loads Webauthn device registrations. func (p *SQLProvider) LoadWebauthnDevices(ctx context.Context, limit, page int) (devices []model.WebauthnDevice, err error) { devices = make([]model.WebauthnDevice, 0, limit) if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevices, limit, limit*page); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("error selecting Webauthn devices: %w", err) } for i, device := range devices { if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil { return nil, fmt.Errorf("error decrypting Webauthn public key for user '%s': %w", device.Username, err) } } return devices, nil } // LoadWebauthnDevicesByUsername loads all webauthn devices registration for a given username. func (p *SQLProvider) LoadWebauthnDevicesByUsername(ctx context.Context, username string) (devices []model.WebauthnDevice, err error) { if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevicesByUsername, username); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoWebauthnDevice } return nil, fmt.Errorf("error selecting Webauthn devices for user '%s': %w", username, err) } for i, device := range devices { if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil { return nil, fmt.Errorf("error decrypting Webauthn public key for user '%s': %w", username, err) } } return devices, nil } func (p *SQLProvider) updateWebauthnDevicePublicKey(ctx context.Context, device model.WebauthnDevice) (err error) { switch device.ID { case 0: _, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDevicePublicKeyByUsername, device.PublicKey, device.Username, device.KID) default: _, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDevicePublicKey, device.PublicKey, device.ID) } if err != nil { return fmt.Errorf("error updating Webauthn public key for user '%s' kid '%x': %w", device.Username, device.KID, err) } return nil } // SavePreferredDuoDevice saves a Duo device. func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil { return fmt.Errorf("error upserting preferred duo device for user '%s': %w", device.Username, err) } return nil } // DeletePreferredDuoDevice deletes a Duo device of a given user. func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username); err != nil { return fmt.Errorf("error deleting preferred duo device for user '%s': %w", username, err) } return nil } // LoadPreferredDuoDevice loads a Duo device of a given user. func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) { device = &model.DuoDevice{} if err = p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil { if err == sql.ErrNoRows { return nil, ErrNoDuoDevice } return nil, fmt.Errorf("error selecting preferred duo device for user '%s': %w", username, err) } return device, nil } // AppendAuthenticationLog append a mark to the authentication log. func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model.AuthenticationAttempt) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Banned, attempt.Username, attempt.Type, attempt.RemoteIP, attempt.RequestURI, attempt.RequestMethod); err != nil { return fmt.Errorf("error inserting authentication attempt for user '%s': %w", attempt.Username, err) } return nil } // LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log. func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) { attempts = make([]model.AuthenticationAttempt, 0, limit) if err = p.db.SelectContext(ctx, &attempts, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoAuthenticationLogs } return nil, fmt.Errorf("error selecting authentication logs for user '%s': %w", username, err) } return attempts, nil }