refactor(random): add random provider (#4712)

This adds a random provider which makes usage of random operations mockable, and may allow us in the future to swap out the Cryptographical CPU random generator with dedicated hardware random generators.
pull/4714/head^2
James Elliott 2023-01-07 11:19:41 +11:00 committed by GitHub
parent f223975e79
commit fc5ea5b485
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 643 additions and 156 deletions

2
go.mod
View File

@ -40,7 +40,6 @@ require (
github.com/trustelem/zxcvbn v1.0.1 github.com/trustelem/zxcvbn v1.0.1
github.com/valyala/fasthttp v1.43.0 github.com/valyala/fasthttp v1.43.0
github.com/wneessen/go-mail v0.3.7 github.com/wneessen/go-mail v0.3.7
golang.org/x/net v0.5.0
golang.org/x/sync v0.1.0 golang.org/x/sync v0.1.0
golang.org/x/term v0.4.0 golang.org/x/term v0.4.0
golang.org/x/text v0.6.0 golang.org/x/text v0.6.0
@ -110,6 +109,7 @@ require (
github.com/ysmood/leakless v0.8.0 // indirect github.com/ysmood/leakless v0.8.0 // indirect
golang.org/x/crypto v0.1.0 // indirect golang.org/x/crypto v0.1.0 // indirect
golang.org/x/mod v0.6.0 // indirect golang.org/x/mod v0.6.0 // indirect
golang.org/x/net v0.5.0 // indirect
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b // indirect golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b // indirect
golang.org/x/sys v0.4.0 // indirect golang.org/x/sys v0.4.0 // indirect
golang.org/x/tools v0.2.0 // indirect golang.org/x/tools v0.2.0 // indirect

View File

@ -1,6 +1,7 @@
package commands package commands
import ( import (
"context"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"os" "os"
@ -8,7 +9,6 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/net/context"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authentication"
@ -22,6 +22,7 @@ import (
"github.com/authelia/authelia/v4/internal/notification" "github.com/authelia/authelia/v4/internal/notification"
"github.com/authelia/authelia/v4/internal/ntp" "github.com/authelia/authelia/v4/internal/ntp"
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/regulation"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/storage"
@ -43,6 +44,9 @@ func NewCmdCtx() *CmdCtx {
cancel: cancel, cancel: cancel,
group: group, group: group,
log: logging.Logger(), log: logging.Logger(),
providers: middlewares.Providers{
Random: &random.Cryptographical{},
},
config: &schema.Configuration{}, config: &schema.Configuration{},
} }
} }
@ -139,48 +143,43 @@ func (ctx *CmdCtx) LoadProviders() (warns, errs []error) {
return warns, errs return warns, errs
} }
storage := getStorageProvider(ctx) ctx.providers.StorageProvider = getStorageProvider(ctx)
providers := middlewares.Providers{ ctx.providers.Authorizer = authorization.NewAuthorizer(ctx.config)
Authorizer: authorization.NewAuthorizer(ctx.config), ctx.providers.NTP = ntp.NewProvider(&ctx.config.NTP)
NTP: ntp.NewProvider(&ctx.config.NTP), ctx.providers.PasswordPolicy = middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy)
PasswordPolicy: middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy), ctx.providers.Regulator = regulation.NewRegulator(ctx.config.Regulation, ctx.providers.StorageProvider, utils.RealClock{})
Regulator: regulation.NewRegulator(ctx.config.Regulation, storage, utils.RealClock{}), ctx.providers.SessionProvider = session.NewProvider(ctx.config.Session, ctx.trusted)
SessionProvider: session.NewProvider(ctx.config.Session, ctx.trusted), ctx.providers.TOTP = totp.NewTimeBasedProvider(ctx.config.TOTP)
StorageProvider: storage,
TOTP: totp.NewTimeBasedProvider(ctx.config.TOTP),
}
var err error var err error
switch { switch {
case ctx.config.AuthenticationBackend.File != nil: case ctx.config.AuthenticationBackend.File != nil:
providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File) ctx.providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File)
case ctx.config.AuthenticationBackend.LDAP != nil: case ctx.config.AuthenticationBackend.LDAP != nil:
providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted) ctx.providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted)
} }
if providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil { if ctx.providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
switch { switch {
case ctx.config.Notifier.SMTP != nil: case ctx.config.Notifier.SMTP != nil:
providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted) ctx.providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted)
case ctx.config.Notifier.FileSystem != nil: case ctx.config.Notifier.FileSystem != nil:
providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem) ctx.providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem)
} }
if providers.OpenIDConnect, err = oidc.NewOpenIDConnectProvider(ctx.config.IdentityProviders.OIDC, storage); err != nil { if ctx.providers.OpenIDConnect, err = oidc.NewOpenIDConnectProvider(ctx.config.IdentityProviders.OIDC, ctx.providers.StorageProvider); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
if ctx.config.Telemetry.Metrics.Enabled { if ctx.config.Telemetry.Metrics.Enabled {
providers.Metrics = metrics.NewPrometheus() ctx.providers.Metrics = metrics.NewPrometheus()
} }
ctx.providers = providers
return warns, errs return warns, errs
} }

View File

@ -3,7 +3,6 @@ package commands
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
@ -262,7 +261,7 @@ func (ctx *CmdCtx) CryptoGenerateRunE(cmd *cobra.Command, args []string) (err er
privateKey any privateKey any
) )
if privateKey, err = cryptoGenPrivateKeyFromCmd(cmd); err != nil { if privateKey, err = ctx.cryptoGenPrivateKeyFromCmd(cmd); err != nil {
return err return err
} }
@ -279,7 +278,7 @@ func (ctx *CmdCtx) CryptoCertificateRequestRunE(cmd *cobra.Command, _ []string)
privateKey any privateKey any
) )
if privateKey, err = cryptoGenPrivateKeyFromCmd(cmd); err != nil { if privateKey, err = ctx.cryptoGenPrivateKeyFromCmd(cmd); err != nil {
return err return err
} }
@ -326,7 +325,7 @@ func (ctx *CmdCtx) CryptoCertificateRequestRunE(cmd *cobra.Command, _ []string)
b.Reset() b.Reset()
if csr, err = x509.CreateCertificateRequest(rand.Reader, template, privateKey); err != nil { if csr, err = x509.CreateCertificateRequest(ctx.providers.Random, template, privateKey); err != nil {
return fmt.Errorf("failed to create certificate request: %w", err) return fmt.Errorf("failed to create certificate request: %w", err)
} }
@ -366,7 +365,7 @@ func (ctx *CmdCtx) CryptoCertificateGenerateRunE(cmd *cobra.Command, _ []string,
signatureKey = caPrivateKey signatureKey = caPrivateKey
} }
if template, err = cryptoGetCertificateFromCmd(cmd); err != nil { if template, err = ctx.cryptoGetCertificateFromCmd(cmd); err != nil {
return err return err
} }
@ -423,7 +422,7 @@ func (ctx *CmdCtx) CryptoCertificateGenerateRunE(cmd *cobra.Command, _ []string,
b.Reset() b.Reset()
if certificate, err = x509.CreateCertificate(rand.Reader, template, parent, publicKey, signatureKey); err != nil { if certificate, err = x509.CreateCertificate(ctx.providers.Random, template, parent, publicKey, signatureKey); err != nil {
return fmt.Errorf("failed to create certificate: %w", err) return fmt.Errorf("failed to create certificate: %w", err)
} }

View File

@ -4,7 +4,6 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
@ -130,7 +129,7 @@ func cryptoGetWritePathsFromCmd(cmd *cobra.Command) (privateKey, publicKey strin
return filepath.Join(dir, private), filepath.Join(dir, public), nil return filepath.Join(dir, private), filepath.Join(dir, public), nil
} }
func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) { func (ctx *CmdCtx) cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) {
switch cmd.Parent().Use { switch cmd.Parent().Use {
case cmdUseRSA: case cmdUseRSA:
var ( var (
@ -141,7 +140,7 @@ func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error)
return nil, err return nil, err
} }
if privateKey, err = rsa.GenerateKey(rand.Reader, bits); err != nil { if privateKey, err = rsa.GenerateKey(ctx.providers.Random, bits); err != nil {
return nil, fmt.Errorf("generating RSA private key resulted in an error: %w", err) return nil, fmt.Errorf("generating RSA private key resulted in an error: %w", err)
} }
case cmdUseECDSA: case cmdUseECDSA:
@ -158,11 +157,11 @@ func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error)
return nil, fmt.Errorf("invalid curve '%s' was specified: curve must be P224, P256, P384, or P521", curveStr) return nil, fmt.Errorf("invalid curve '%s' was specified: curve must be P224, P256, P384, or P521", curveStr)
} }
if privateKey, err = ecdsa.GenerateKey(curve, rand.Reader); err != nil { if privateKey, err = ecdsa.GenerateKey(curve, ctx.providers.Random); err != nil {
return nil, fmt.Errorf("generating ECDSA private key resulted in an error: %w", err) return nil, fmt.Errorf("generating ECDSA private key resulted in an error: %w", err)
} }
case cmdUseEd25519: case cmdUseEd25519:
if _, privateKey, err = ed25519.GenerateKey(rand.Reader); err != nil { if _, privateKey, err = ed25519.GenerateKey(ctx.providers.Random); err != nil {
return nil, fmt.Errorf("generating Ed25519 private key resulted in an error: %w", err) return nil, fmt.Errorf("generating Ed25519 private key resulted in an error: %w", err)
} }
} }
@ -336,7 +335,7 @@ func cryptoGetSubjectFromCmd(cmd *cobra.Command) (subject *pkix.Name, err error)
}, nil }, nil
} }
func cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certificate, err error) { func (ctx *CmdCtx) cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certificate, err error) {
var ( var (
ca bool ca bool
subject *pkix.Name subject *pkix.Name
@ -378,7 +377,7 @@ func cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certific
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
if serialNumber, err = rand.Int(rand.Reader, serialNumberLimit); err != nil { if serialNumber, err = ctx.providers.Random.IntErr(serialNumberLimit); err != nil {
return nil, fmt.Errorf("failed to generate serial number: %w", err) return nil, fmt.Errorf("failed to generate serial number: %w", err)
} }

View File

@ -18,6 +18,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/validator" "github.com/authelia/authelia/v4/internal/configuration/validator"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/totp" "github.com/authelia/authelia/v4/internal/totp"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
@ -983,7 +984,8 @@ func (ctx *CmdCtx) StorageUserTOTPExportPNGRunE(cmd *cobra.Command, _ []string)
} }
if dir == "" { if dir == "" {
dir = utils.RandomString(8, utils.CharSetAlphaNumeric) rand := &random.Cryptographical{}
dir = rand.StringCustom(8, random.CharSetAlphaNumeric)
} }
if _, err = os.Stat(dir); !os.IsNotExist(err) { if _, err = os.Stat(dir); !os.IsNotExist(err) {

View File

@ -14,7 +14,7 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/authelia/authelia/v4/internal/configuration" "github.com/authelia/authelia/v4/internal/configuration"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/random"
) )
func recoverErr(i any) error { func recoverErr(i any) error {
@ -77,29 +77,29 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar
switch c { switch c {
case "ascii": case "ascii":
charset = utils.CharSetASCII charset = random.CharSetASCII
case "alphanumeric": case "alphanumeric":
charset = utils.CharSetAlphaNumeric charset = random.CharSetAlphaNumeric
case "alphanumeric-lower": case "alphanumeric-lower":
charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric charset = random.CharSetAlphabeticLower + random.CharSetNumeric
case "alphanumeric-upper": case "alphanumeric-upper":
charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric charset = random.CharSetAlphabeticUpper + random.CharSetNumeric
case "alphabetic": case "alphabetic":
charset = utils.CharSetAlphabetic charset = random.CharSetAlphabetic
case "alphabetic-lower": case "alphabetic-lower":
charset = utils.CharSetAlphabeticLower charset = random.CharSetAlphabeticLower
case "alphabetic-upper": case "alphabetic-upper":
charset = utils.CharSetAlphabeticUpper charset = random.CharSetAlphabeticUpper
case "numeric-hex": case "numeric-hex":
charset = utils.CharSetNumericHex charset = random.CharSetNumericHex
case "numeric": case "numeric":
charset = utils.CharSetNumeric charset = random.CharSetNumeric
case "rfc3986": case "rfc3986":
charset = utils.CharSetRFC3986Unreserved charset = random.CharSetRFC3986Unreserved
case "rfc3986-lower": case "rfc3986-lower":
charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved charset = random.CharSetAlphabeticLower + random.CharSetNumeric + random.CharSetSymbolicRFC3986Unreserved
case "rfc3986-upper": case "rfc3986-upper":
charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved charset = random.CharSetAlphabeticUpper + random.CharSetNumeric + random.CharSetSymbolicRFC3986Unreserved
default: default:
return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c) return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c)
} }
@ -109,7 +109,9 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar
} }
} }
return utils.RandomString(n, charset), nil rand := &random.Cryptographical{}
return rand.StringCustom(n, charset), nil
} }
func termReadConfirmation(flags *pflag.FlagSet, name, prompt, confirmation string) (confirmed bool, err error) { func termReadConfirmation(flags *pflag.FlagSet, name, prompt, confirmation string) (confirmed bool, err error) {

View File

@ -13,6 +13,7 @@ import (
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
) )
@ -128,6 +129,7 @@ func TestShouldCallNextWithAutheliaCtx(t *testing.T) {
providers := middlewares.Providers{ providers := middlewares.Providers{
UserProvider: userProvider, UserProvider: userProvider,
SessionProvider: sessionProvider, SessionProvider: sessionProvider,
Random: random.NewMathematical(),
} }
nextCalled := false nextCalled := false

View File

@ -11,6 +11,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/mocks"
@ -73,8 +74,8 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
defer mock.Close() defer mock.Close()
mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Configuration.JWTSecret = testJWTSecret
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedProto, "http")
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedHost, "host")
mock.StorageMock.EXPECT(). mock.StorageMock.EXPECT().
SaveIdentityVerification(mock.Ctx, gomock.Any()). SaveIdentityVerification(mock.Ctx, gomock.Any()).
@ -95,8 +96,8 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t) mock := mocks.NewMockAutheliaCtx(t)
mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Configuration.JWTSecret = testJWTSecret
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedProto, "http")
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedHost, "host")
mock.StorageMock.EXPECT(). mock.StorageMock.EXPECT().
SaveIdentityVerification(mock.Ctx, gomock.Any()). SaveIdentityVerification(mock.Ctx, gomock.Any()).

View File

@ -1,7 +1,6 @@
package middlewares package middlewares
import ( import (
"crypto/rand"
"math" "math"
"math/big" "math/big"
"sync" "sync"
@ -62,7 +61,7 @@ func movingAverageIteration(value time.Duration, history int, successful bool, c
} }
func calculateActualDelay(ctx *AutheliaCtx, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) { func calculateActualDelay(ctx *AutheliaCtx, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs)) randomDelayMs, err := ctx.Providers.Random.IntErr(big.NewInt(maxRandomMs))
if err != nil { if err != nil {
return float64(maxRandomMs) return float64(maxRandomMs)
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/random"
) )
func TestTimingAttackDelayAverages(t *testing.T) { func TestTimingAttackDelayAverages(t *testing.T) {
@ -45,7 +46,12 @@ func TestTimingAttackDelayCalculations(t *testing.T) {
avgExecDurationMs := 1000.0 avgExecDurationMs := 1000.0
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds()) expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
ctx := &AutheliaCtx{Logger: logging.Logger().WithFields(logrus.Fields{})} ctx := &AutheliaCtx{
Logger: logging.Logger().WithFields(logrus.Fields{}),
Providers: Providers{
Random: &random.Cryptographical{},
},
}
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
delay := calculateActualDelay(ctx, execDuration, avgExecDurationMs, 250, 85, false) delay := calculateActualDelay(ctx, execDuration, avgExecDurationMs, 250, 85, false)

View File

@ -11,6 +11,7 @@ import (
"github.com/authelia/authelia/v4/internal/notification" "github.com/authelia/authelia/v4/internal/notification"
"github.com/authelia/authelia/v4/internal/ntp" "github.com/authelia/authelia/v4/internal/ntp"
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/regulation"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/storage"
@ -44,6 +45,7 @@ type Providers struct {
Templates *templates.Provider Templates *templates.Provider
TOTP totp.Provider TOTP totp.Provider
PasswordPolicy PasswordPolicyProvider PasswordPolicy PasswordPolicyProvider
Random random.Provider
} }
// RequestHandler represents an Authelia request handler. // RequestHandler represents an Authelia request handler.

View File

@ -16,6 +16,7 @@ import (
"github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/regulation"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/templates" "github.com/authelia/authelia/v4/internal/templates"
@ -34,6 +35,7 @@ type MockAutheliaCtx struct {
StorageMock *MockStorage StorageMock *MockStorage
NotifierMock *MockNotifier NotifierMock *MockNotifier
TOTPMock *MockTOTP TOTPMock *MockTOTP
RandomMock *MockRandom
UserSession *session.UserSession UserSession *session.UserSession
@ -98,6 +100,10 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx {
mockAuthelia.TOTPMock = NewMockTOTP(mockAuthelia.Ctrl) mockAuthelia.TOTPMock = NewMockTOTP(mockAuthelia.Ctrl)
providers.TOTP = mockAuthelia.TOTPMock providers.TOTP = mockAuthelia.TOTPMock
mockAuthelia.RandomMock = NewMockRandom(mockAuthelia.Ctrl)
providers.Random = random.NewMathematical()
var err error var err error
if providers.Templates, err = templates.New(templates.Config{}); err != nil { if providers.Templates, err = templates.New(templates.Config{}); err != nil {

View File

@ -8,3 +8,4 @@ package mocks
//go:generate mockgen -package mocks -destination totp.go -mock_names Provider=MockTOTP github.com/authelia/authelia/v4/internal/totp Provider //go:generate mockgen -package mocks -destination totp.go -mock_names Provider=MockTOTP github.com/authelia/authelia/v4/internal/totp Provider
//go:generate mockgen -package mocks -destination storage.go -mock_names Provider=MockStorage github.com/authelia/authelia/v4/internal/storage Provider //go:generate mockgen -package mocks -destination storage.go -mock_names Provider=MockStorage github.com/authelia/authelia/v4/internal/storage Provider
//go:generate mockgen -package mocks -destination duo_api.go -mock_names API=MockAPI github.com/authelia/authelia/v4/internal/duo API //go:generate mockgen -package mocks -destination duo_api.go -mock_names API=MockAPI github.com/authelia/authelia/v4/internal/duo API
//go:generate mockgen -package mocks -destination random.go -mock_names Provider=MockRandom github.com/authelia/authelia/v4/internal/random Provider

View File

@ -0,0 +1,195 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/authelia/authelia/v4/internal/random (interfaces: Provider)
// Package mocks is a generated GoMock package.
package mocks
import (
big "math/big"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockRandom is a mock of Provider interface.
type MockRandom struct {
ctrl *gomock.Controller
recorder *MockRandomMockRecorder
}
// MockRandomMockRecorder is the mock recorder for MockRandom.
type MockRandomMockRecorder struct {
mock *MockRandom
}
// NewMockRandom creates a new mock instance.
func NewMockRandom(ctrl *gomock.Controller) *MockRandom {
mock := &MockRandom{ctrl: ctrl}
mock.recorder = &MockRandomMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRandom) EXPECT() *MockRandomMockRecorder {
return m.recorder
}
// Bytes mocks base method.
func (m *MockRandom) Bytes() []byte {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Bytes")
ret0, _ := ret[0].([]byte)
return ret0
}
// Bytes indicates an expected call of Bytes.
func (mr *MockRandomMockRecorder) Bytes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockRandom)(nil).Bytes))
}
// BytesCustom mocks base method.
func (m *MockRandom) BytesCustom(arg0 int, arg1 []byte) []byte {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BytesCustom", arg0, arg1)
ret0, _ := ret[0].([]byte)
return ret0
}
// BytesCustom indicates an expected call of BytesCustom.
func (mr *MockRandomMockRecorder) BytesCustom(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesCustom", reflect.TypeOf((*MockRandom)(nil).BytesCustom), arg0, arg1)
}
// BytesCustomErr mocks base method.
func (m *MockRandom) BytesCustomErr(arg0 int, arg1 []byte) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BytesCustomErr", arg0, arg1)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BytesCustomErr indicates an expected call of BytesCustomErr.
func (mr *MockRandomMockRecorder) BytesCustomErr(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesCustomErr", reflect.TypeOf((*MockRandom)(nil).BytesCustomErr), arg0, arg1)
}
// BytesErr mocks base method.
func (m *MockRandom) BytesErr() ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BytesErr")
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BytesErr indicates an expected call of BytesErr.
func (mr *MockRandomMockRecorder) BytesErr() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesErr", reflect.TypeOf((*MockRandom)(nil).BytesErr))
}
// Int mocks base method.
func (m *MockRandom) Int(arg0 *big.Int) *big.Int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Int", arg0)
ret0, _ := ret[0].(*big.Int)
return ret0
}
// Int indicates an expected call of Int.
func (mr *MockRandomMockRecorder) Int(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int", reflect.TypeOf((*MockRandom)(nil).Int), arg0)
}
// IntErr mocks base method.
func (m *MockRandom) IntErr(arg0 *big.Int) (*big.Int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntErr", arg0)
ret0, _ := ret[0].(*big.Int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IntErr indicates an expected call of IntErr.
func (mr *MockRandomMockRecorder) IntErr(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntErr", reflect.TypeOf((*MockRandom)(nil).IntErr), arg0)
}
// Integer mocks base method.
func (m *MockRandom) Integer(arg0 int) int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Integer", arg0)
ret0, _ := ret[0].(int)
return ret0
}
// Integer indicates an expected call of Integer.
func (mr *MockRandomMockRecorder) Integer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Integer", reflect.TypeOf((*MockRandom)(nil).Integer), arg0)
}
// IntegerErr mocks base method.
func (m *MockRandom) IntegerErr(arg0 int) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntegerErr", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IntegerErr indicates an expected call of IntegerErr.
func (mr *MockRandomMockRecorder) IntegerErr(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntegerErr", reflect.TypeOf((*MockRandom)(nil).IntegerErr), arg0)
}
// Read mocks base method.
func (m *MockRandom) Read(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read.
func (mr *MockRandomMockRecorder) Read(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockRandom)(nil).Read), arg0)
}
// StringCustom mocks base method.
func (m *MockRandom) StringCustom(arg0 int, arg1 string) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StringCustom", arg0, arg1)
ret0, _ := ret[0].(string)
return ret0
}
// StringCustom indicates an expected call of StringCustom.
func (mr *MockRandomMockRecorder) StringCustom(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringCustom", reflect.TypeOf((*MockRandom)(nil).StringCustom), arg0, arg1)
}
// StringCustomErr mocks base method.
func (m *MockRandom) StringCustomErr(arg0 int, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StringCustomErr", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// StringCustomErr indicates an expected call of StringCustomErr.
func (mr *MockRandomMockRecorder) StringCustomErr(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringCustomErr", reflect.TypeOf((*MockRandom)(nil).StringCustomErr), arg0, arg1)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/templates" "github.com/authelia/authelia/v4/internal/templates"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -64,6 +65,7 @@ func NewSMTPNotifier(config *schema.SMTPNotifierConfiguration, certPool *x509.Ce
return &SMTPNotifier{ return &SMTPNotifier{
config: config, config: config,
domain: domain, domain: domain,
random: &random.Cryptographical{},
tls: utils.NewTLSConfig(config.TLS, certPool), tls: utils.NewTLSConfig(config.TLS, certPool),
log: logging.Logger(), log: logging.Logger(),
opts: opts, opts: opts,
@ -74,6 +76,7 @@ func NewSMTPNotifier(config *schema.SMTPNotifierConfiguration, certPool *x509.Ce
type SMTPNotifier struct { type SMTPNotifier struct {
config *schema.SMTPNotifierConfiguration config *schema.SMTPNotifierConfiguration
domain string domain string
random random.Provider
tls *tls.Config tls *tls.Config
log *logrus.Logger log *logrus.Logger
opts []gomail.Option opts []gomail.Option
@ -104,10 +107,10 @@ func (n *SMTPNotifier) StartupCheck() (err error) {
func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject string, et *templates.EmailTemplate, data any) (err error) { func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject string, et *templates.EmailTemplate, data any) (err error) {
msg := gomail.NewMsg( msg := gomail.NewMsg(
gomail.WithMIMEVersion(gomail.Mime10), gomail.WithMIMEVersion(gomail.Mime10),
gomail.WithBoundary(utils.RandomString(30, utils.CharSetAlphaNumeric)), gomail.WithBoundary(n.random.StringCustom(30, random.CharSetAlphaNumeric)),
) )
setMessageID(msg, n.domain) n.setMessageID(msg, n.domain)
if err = msg.From(n.config.Sender.String()); err != nil { if err = msg.From(n.config.Sender.String()); err != nil {
return fmt.Errorf("notifier: smtp: failed to set from address: %w", err) return fmt.Errorf("notifier: smtp: failed to set from address: %w", err)
@ -161,10 +164,10 @@ func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject
return nil return nil
} }
func setMessageID(msg *gomail.Msg, domain string) { func (n *SMTPNotifier) setMessageID(msg *gomail.Msg, domain string) {
rn, _ := utils.RandomInt(100000000) rn := n.random.Integer(100000000)
rm, _ := utils.RandomInt(10000) rm := n.random.Integer(10000)
rs := utils.RandomString(17, utils.CharSetAlphaNumeric) rs := n.random.StringCustom(17, random.CharSetAlphaNumeric)
pid := os.Getpid() + rm pid := os.Getpid() + rm
msg.SetMessageIDWithValue(fmt.Sprintf("%d.%d%d.%s@%s", pid, rn, rm, rs, domain)) msg.SetMessageIDWithValue(fmt.Sprintf("%d.%d%d.%s@%s", pid, rn, rm, rs, domain))

View File

@ -0,0 +1,43 @@
package random
const (
// DefaultN is the default value of n.
DefaultN = 72
)
const (
// CharSetAlphabeticLower are literally just valid alphabetic lowercase printable ASCII chars.
CharSetAlphabeticLower = "abcdefghijklmnopqrstuvwxyz"
// CharSetAlphabeticUpper are literally just valid alphabetic uppercase printable ASCII chars.
CharSetAlphabeticUpper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
// CharSetAlphabetic are literally just valid alphabetic printable ASCII chars.
CharSetAlphabetic = CharSetAlphabeticLower + CharSetAlphabeticUpper
// CharSetNumeric are literally just valid numeric chars.
CharSetNumeric = "0123456789"
// CharSetNumericHex are literally just valid hexadecimal printable ASCII chars.
CharSetNumericHex = CharSetNumeric + "ABCDEF"
// CharSetSymbolic are literally just valid symbolic printable ASCII chars.
CharSetSymbolic = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
// CharSetSymbolicRFC3986Unreserved are RFC3986 unreserved symbol characters.
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
CharSetSymbolicRFC3986Unreserved = "-._~"
// CharSetAlphaNumeric are literally just valid alphanumeric printable ASCII chars.
CharSetAlphaNumeric = CharSetAlphabetic + CharSetNumeric
// CharSetASCII are literally just valid printable ASCII chars.
CharSetASCII = CharSetAlphabetic + CharSetNumeric + CharSetSymbolic
// CharSetRFC3986Unreserved are RFC3986 unreserved characters.
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
CharSetRFC3986Unreserved = CharSetAlphabetic + CharSetNumeric + CharSetSymbolicRFC3986Unreserved
// CharSetUnambiguousUpper are a set of unambiguous uppercase characters.
CharSetUnambiguousUpper = "ABCDEFGHJKLMNOPQRTUVWYXZ2346789"
)

View File

@ -0,0 +1,136 @@
package random
import (
"crypto/rand"
"fmt"
"io"
"math/big"
)
// Cryptographical is the production random.Provider which uses crypto/rand.
type Cryptographical struct{}
// Read implements the io.Reader interface.
func (r *Cryptographical) Read(p []byte) (n int, err error) {
return io.ReadFull(rand.Reader, p)
}
// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values
// (including unreadable byte values). If an error is returned from the random read this function returns it.
func (r *Cryptographical) BytesErr() (data []byte, err error) {
data = make([]byte, DefaultN)
_, err = rand.Read(data)
return data, err
}
// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values
// (including unreadable byte values). If an error is returned from the random read this function ignores it.
func (r *Cryptographical) Bytes() (data []byte) {
data, _ = r.BytesErr()
return data
}
// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided
// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
// returns it.
func (r *Cryptographical) BytesCustomErr(n int, charset []byte) (data []byte, err error) {
if n < 1 {
n = DefaultN
}
data = make([]byte, n)
if _, err = rand.Read(data); err != nil {
return nil, err
}
t := len(charset)
for i := 0; i < n; i++ {
data[i] = charset[data[i]%byte(t)]
}
return data, nil
}
// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string.
func (r *Cryptographical) StringCustomErr(n int, characters string) (data string, err error) {
var d []byte
if d, err = r.BytesCustomErr(n, []byte(characters)); err != nil {
return "", err
}
return string(d), nil
}
// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided values.
// If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
// ignores it.
func (r *Cryptographical) BytesCustom(n int, charset []byte) (data []byte) {
data, _ = r.BytesCustomErr(n, charset)
return data
}
// StringCustom is an overload of BytesCustom which takes a characters string and returns a string.
func (r *Cryptographical) StringCustom(n int, characters string) (data string) {
return string(r.BytesCustom(n, []byte(characters)))
}
// IntErr returns a random *big.Int error combination with a maximum of max.
func (r *Cryptographical) IntErr(max *big.Int) (value *big.Int, err error) {
if max == nil {
return nil, fmt.Errorf("max is required")
}
if max.Sign() <= 0 {
return nil, fmt.Errorf("max must be 1 or more")
}
return rand.Int(rand.Reader, max)
}
// Int returns a random *big.Int with a maximum of max.
func (r *Cryptographical) Int(max *big.Int) (value *big.Int) {
var err error
if value, err = r.IntErr(max); err != nil {
return big.NewInt(-1)
}
return value
}
// IntegerErr returns a random int error combination with a maximum of n.
func (r *Cryptographical) IntegerErr(n int) (value int, err error) {
if n <= 0 {
return 0, fmt.Errorf("n must be more than 0")
}
max := big.NewInt(int64(n))
var result *big.Int
if result, err = r.IntErr(max); err != nil {
return 0, err
}
value = int(result.Int64())
if value < 0 {
return 0, fmt.Errorf("generated number is too big for int")
}
return value, nil
}
// Integer returns a random int with a maximum of n.
func (r *Cryptographical) Integer(n int) (value int) {
value, _ = r.IntegerErr(n)
return value
}

View File

@ -0,0 +1,126 @@
package random
import (
"fmt"
"math/big"
"math/rand"
"time"
)
// NewMathematical runs rand.Seed with the current time and returns a random.Provider, specifically *random.Mathematical.
func NewMathematical() *Mathematical {
rand.Seed(time.Now().UnixNano())
return &Mathematical{}
}
// Mathematical is the random.Provider which uses math/rand and is COMPLETELY UNSAFE FOR PRODUCTION IN MOST SITUATIONS.
// Use random.Cryptographical instead.
type Mathematical struct{}
// Read implements the io.Reader interface.
func (r *Mathematical) Read(p []byte) (n int, err error) {
return rand.Read(p) //nolint:gosec
}
// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values
// (including unreadable byte values). If an error is returned from the random read this function returns it.
func (r *Mathematical) BytesErr() (data []byte, err error) {
data = make([]byte, DefaultN)
if _, err = rand.Read(data); err != nil { //nolint:gosec
return nil, err
}
return data, nil
}
// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values
// (including unreadable byte values). If an error is returned from the random read this function ignores it.
func (r *Mathematical) Bytes() (data []byte) {
data, _ = r.BytesErr()
return data
}
// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided
// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
// returns it.
func (r *Mathematical) BytesCustomErr(n int, charset []byte) (data []byte, err error) {
if n < 1 {
n = DefaultN
}
data = make([]byte, n)
if _, err = rand.Read(data); err != nil { //nolint:gosec
return nil, err
}
t := len(charset)
for i := 0; i < n; i++ {
data[i] = charset[data[i]%byte(t)]
}
return data, nil
}
// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string.
func (r *Mathematical) StringCustomErr(n int, characters string) (data string, err error) {
var d []byte
if d, err = r.BytesCustomErr(n, []byte(characters)); err != nil {
return "", err
}
return string(d), nil
}
// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided values.
// If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
// ignores it.
func (r *Mathematical) BytesCustom(n int, charset []byte) (data []byte) {
data, _ = r.BytesCustomErr(n, charset)
return data
}
// StringCustom is an overload of BytesCustom which takes a characters string and returns a string.
func (r *Mathematical) StringCustom(n int, characters string) (data string) {
return string(r.BytesCustom(n, []byte(characters)))
}
// IntErr returns a random *big.Int error combination with a maximum of max.
func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) {
if max == nil {
return nil, fmt.Errorf("max is required")
}
if max.Sign() <= 0 {
return nil, fmt.Errorf("max must be 1 or more")
}
return big.NewInt(int64(rand.Intn(max.Sign()))), nil //nolint:gosec
}
// Int returns a random *big.Int with a maximum of max.
func (r *Mathematical) Int(max *big.Int) (value *big.Int) {
var err error
if value, err = r.IntErr(max); err != nil {
return big.NewInt(-1)
}
return value
}
// IntegerErr returns a random int error combination with a maximum of n.
func (r *Mathematical) IntegerErr(n int) (output int, err error) {
return r.Integer(n), nil
}
// Integer returns a random int with a maximum of n.
func (r *Mathematical) Integer(n int) int {
return rand.Intn(n) //nolint:gosec
}

View File

@ -0,0 +1,46 @@
package random
import (
"io"
"math/big"
)
// Provider of random functions and functionality.
type Provider interface {
io.Reader
// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values
// (including unreadable byte values). If an error is returned from the random read this function returns it.
BytesErr() (data []byte, err error)
// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values
// (including unreadable byte values). If an error is returned from the random read this function ignores it.
Bytes() (data []byte)
// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided
// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
// returns it.
BytesCustomErr(n int, charset []byte) (data []byte, err error)
// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string.
StringCustomErr(n int, characters string) (data string, err error)
// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided
// values. If n is less than 1 then DefaultN is used instead.
BytesCustom(n int, charset []byte) (data []byte)
// StringCustom is an overload of GenerateCustom which takes a characters string and returns a string.
StringCustom(n int, characters string) (data string)
// IntErr returns a random *big.Int error combination with a maximum of max.
IntErr(max *big.Int) (value *big.Int, err error)
// Int returns a random *big.Int with a maximum of max.
Int(max *big.Int) (value *big.Int)
// IntegerErr returns a random int error combination with a maximum of n.
IntegerErr(n int) (value int, err error)
// Integer returns a random integer with a maximum of n.
Integer(n int) (value int)
}

View File

@ -15,8 +15,8 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/random"
"github.com/authelia/authelia/v4/internal/templates" "github.com/authelia/authelia/v4/internal/templates"
"github.com/authelia/authelia/v4/internal/utils"
) )
// ServeTemplatedFile serves a templated version of a specified file, // ServeTemplatedFile serves a templated version of a specified file,
@ -46,7 +46,7 @@ func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middle
ctx.SetContentTypeTextPlain() ctx.SetContentTypeTextPlain()
} }
nonce := utils.RandomString(32, utils.CharSetAlphaNumeric) nonce := ctx.Providers.Random.StringCustom(32, random.CharSetAlphaNumeric)
switch { switch {
case ctx.Configuration.Server.Headers.CSPTemplate != "": case ctx.Configuration.Server.Headers.CSPTemplate != "":
@ -78,7 +78,7 @@ func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) mid
if spec { if spec {
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, tmplCSPSwagger) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, tmplCSPSwagger)
} else { } else {
nonce = utils.RandomString(32, utils.CharSetAlphaNumeric) nonce = ctx.Providers.Random.StringCustom(32, random.CharSetAlphaNumeric)
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPSwaggerNonce, nonce, nonce)) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPSwaggerNonce, nonce, nonce))
} }

View File

@ -97,40 +97,6 @@ const (
timeUnixEpochAsMicrosoftNTEpoch uint64 = 116444736000000000 timeUnixEpochAsMicrosoftNTEpoch uint64 = 116444736000000000
) )
const (
// CharSetAlphabeticLower are literally just valid alphabetic lowercase printable ASCII chars.
CharSetAlphabeticLower = "abcdefghijklmnopqrstuvwxyz"
// CharSetAlphabeticUpper are literally just valid alphabetic uppercase printable ASCII chars.
CharSetAlphabeticUpper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
// CharSetAlphabetic are literally just valid alphabetic printable ASCII chars.
CharSetAlphabetic = CharSetAlphabeticLower + CharSetAlphabeticUpper
// CharSetNumeric are literally just valid numeric chars.
CharSetNumeric = "0123456789"
// CharSetNumericHex are literally just valid hexadecimal printable ASCII chars.
CharSetNumericHex = CharSetNumeric + "ABCDEF"
// CharSetSymbolic are literally just valid symbolic printable ASCII chars.
CharSetSymbolic = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
// CharSetSymbolicRFC3986Unreserved are RFC3986 unreserved symbol characters.
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
CharSetSymbolicRFC3986Unreserved = "-._~"
// CharSetAlphaNumeric are literally just valid alphanumeric printable ASCII chars.
CharSetAlphaNumeric = CharSetAlphabetic + CharSetNumeric
// CharSetASCII are literally just valid printable ASCII chars.
CharSetASCII = CharSetAlphabetic + CharSetNumeric + CharSetSymbolic
// CharSetRFC3986Unreserved are RFC3986 unreserved characters.
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
CharSetRFC3986Unreserved = CharSetAlphabetic + CharSetNumeric + CharSetSymbolicRFC3986Unreserved
)
var htmlEscaper = strings.NewReplacer( var htmlEscaper = strings.NewReplacer(
"&", "&amp;", "&", "&amp;",
"<", "&lt;", "<", "&lt;",

View File

@ -1,5 +1,11 @@
package utils package utils
import (
"github.com/authelia/authelia/v4/internal/random"
)
const ( const (
testStringInput = "abcdefghijkl" testStringInput = "abcdefghijkl"
) )
var r = &random.Cryptographical{}

View File

@ -573,50 +573,3 @@ loop:
return extKeyUsage return extKeyUsage
} }
// RandomString returns a random string with a given length with values from the provided characters. When crypto is set
// to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true
// excluding when the task is time sensitive and would not benefit from extra randomness.
func RandomString(n int, characters string) (randomString string) {
return string(RandomBytes(n, characters))
}
// RandomBytes returns a random []byte with a given length with values from the provided characters. When crypto is set
// to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true
// excluding when the task is time sensitive and would not benefit from extra randomness.
func RandomBytes(n int, characters string) (bytes []byte) {
bytes = make([]byte, n)
_, _ = rand.Read(bytes)
for i, b := range bytes {
bytes[i] = characters[b%byte(len(characters))]
}
return bytes
}
func RandomInt(n int) (int, error) {
if n <= 0 {
return 0, fmt.Errorf("n must be more than 0")
}
max := big.NewInt(int64(n))
if !max.IsUint64() {
return 0, fmt.Errorf("generated max is negative")
}
value, err := rand.Int(rand.Reader, max)
if err != nil {
return 0, err
}
output := int(value.Int64())
if output < 0 {
return 0, fmt.Errorf("generated number is too big for int")
}
return output, nil
}

View File

@ -7,6 +7,8 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/authelia/authelia/v4/internal/random"
) )
func TestShouldHashString(t *testing.T) { func TestShouldHashString(t *testing.T) {
@ -22,7 +24,7 @@ func TestShouldHashString(t *testing.T) {
assert.Equal(t, "ae448ac86c4e8e4dec645729708ef41873ae79c6dff84eff73360989487f08e5", anotherSum) assert.Equal(t, "ae448ac86c4e8e4dec645729708ef41873ae79c6dff84eff73360989487f08e5", anotherSum)
assert.NotEqual(t, sum, anotherSum) assert.NotEqual(t, sum, anotherSum)
randomInput := RandomString(40, CharSetAlphaNumeric) randomInput := r.StringCustom(40, random.CharSetAlphaNumeric)
randomSum := HashSHA256FromString(randomInput) randomSum := HashSHA256FromString(randomInput)
assert.NotEqual(t, randomSum, sum) assert.NotEqual(t, randomSum, sum)
@ -38,7 +40,7 @@ func TestShouldHashPath(t *testing.T) {
err = os.WriteFile(filepath.Join(dir, "anotherfile"), []byte("another\n"), 0600) err = os.WriteFile(filepath.Join(dir, "anotherfile"), []byte("another\n"), 0600)
assert.NoError(t, err) assert.NoError(t, err)
err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(RandomString(40, CharSetAlphaNumeric)+"\n"), 0600) err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(r.StringCustom(40, random.CharSetAlphaNumeric)+"\n"), 0600)
assert.NoError(t, err) assert.NoError(t, err)
sum, err := HashSHA256FromPath(filepath.Join(dir, "myfile")) sum, err := HashSHA256FromPath(filepath.Join(dir, "myfile"))

View File

@ -53,13 +53,6 @@ func TestStringJoinDelimitedEscaped(t *testing.T) {
} }
} }
func TestShouldNotGenerateSameRandomString(t *testing.T) {
randomStringOne := RandomString(10, CharSetAlphaNumeric)
randomStringTwo := RandomString(10, CharSetAlphaNumeric)
assert.NotEqual(t, randomStringOne, randomStringTwo)
}
func TestShouldDetectAlphaNumericString(t *testing.T) { func TestShouldDetectAlphaNumericString(t *testing.T) {
assert.True(t, IsStringAlphaNumeric("abc")) assert.True(t, IsStringAlphaNumeric("abc"))
assert.True(t, IsStringAlphaNumeric("abc123")) assert.True(t, IsStringAlphaNumeric("abc123"))