refactor(random): add random provider (#4712)
This adds a random provider which makes usage of random operations mockable, and may allow us in the future to swap out the Cryptographical CPU random generator with dedicated hardware random generators.pull/4714/head^2
parent
f223975e79
commit
fc5ea5b485
2
go.mod
2
go.mod
|
@ -40,7 +40,6 @@ require (
|
|||
github.com/trustelem/zxcvbn v1.0.1
|
||||
github.com/valyala/fasthttp v1.43.0
|
||||
github.com/wneessen/go-mail v0.3.7
|
||||
golang.org/x/net v0.5.0
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/term v0.4.0
|
||||
golang.org/x/text v0.6.0
|
||||
|
@ -110,6 +109,7 @@ require (
|
|||
github.com/ysmood/leakless v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.1.0 // indirect
|
||||
golang.org/x/mod v0.6.0 // indirect
|
||||
golang.org/x/net v0.5.0 // indirect
|
||||
golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b // indirect
|
||||
golang.org/x/sys v0.4.0 // indirect
|
||||
golang.org/x/tools v0.2.0 // indirect
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -8,7 +9,6 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/notification"
|
||||
"github.com/authelia/authelia/v4/internal/ntp"
|
||||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/regulation"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
|
@ -43,6 +44,9 @@ func NewCmdCtx() *CmdCtx {
|
|||
cancel: cancel,
|
||||
group: group,
|
||||
log: logging.Logger(),
|
||||
providers: middlewares.Providers{
|
||||
Random: &random.Cryptographical{},
|
||||
},
|
||||
config: &schema.Configuration{},
|
||||
}
|
||||
}
|
||||
|
@ -139,48 +143,43 @@ func (ctx *CmdCtx) LoadProviders() (warns, errs []error) {
|
|||
return warns, errs
|
||||
}
|
||||
|
||||
storage := getStorageProvider(ctx)
|
||||
ctx.providers.StorageProvider = getStorageProvider(ctx)
|
||||
|
||||
providers := middlewares.Providers{
|
||||
Authorizer: authorization.NewAuthorizer(ctx.config),
|
||||
NTP: ntp.NewProvider(&ctx.config.NTP),
|
||||
PasswordPolicy: middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy),
|
||||
Regulator: regulation.NewRegulator(ctx.config.Regulation, storage, utils.RealClock{}),
|
||||
SessionProvider: session.NewProvider(ctx.config.Session, ctx.trusted),
|
||||
StorageProvider: storage,
|
||||
TOTP: totp.NewTimeBasedProvider(ctx.config.TOTP),
|
||||
}
|
||||
ctx.providers.Authorizer = authorization.NewAuthorizer(ctx.config)
|
||||
ctx.providers.NTP = ntp.NewProvider(&ctx.config.NTP)
|
||||
ctx.providers.PasswordPolicy = middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy)
|
||||
ctx.providers.Regulator = regulation.NewRegulator(ctx.config.Regulation, ctx.providers.StorageProvider, utils.RealClock{})
|
||||
ctx.providers.SessionProvider = session.NewProvider(ctx.config.Session, ctx.trusted)
|
||||
ctx.providers.TOTP = totp.NewTimeBasedProvider(ctx.config.TOTP)
|
||||
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case ctx.config.AuthenticationBackend.File != nil:
|
||||
providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File)
|
||||
ctx.providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File)
|
||||
case ctx.config.AuthenticationBackend.LDAP != nil:
|
||||
providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted)
|
||||
ctx.providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted)
|
||||
}
|
||||
|
||||
if providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil {
|
||||
if ctx.providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case ctx.config.Notifier.SMTP != nil:
|
||||
providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted)
|
||||
ctx.providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted)
|
||||
case ctx.config.Notifier.FileSystem != nil:
|
||||
providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem)
|
||||
ctx.providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem)
|
||||
}
|
||||
|
||||
if providers.OpenIDConnect, err = oidc.NewOpenIDConnectProvider(ctx.config.IdentityProviders.OIDC, storage); err != nil {
|
||||
if ctx.providers.OpenIDConnect, err = oidc.NewOpenIDConnectProvider(ctx.config.IdentityProviders.OIDC, ctx.providers.StorageProvider); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if ctx.config.Telemetry.Metrics.Enabled {
|
||||
providers.Metrics = metrics.NewPrometheus()
|
||||
ctx.providers.Metrics = metrics.NewPrometheus()
|
||||
}
|
||||
|
||||
ctx.providers = providers
|
||||
|
||||
return warns, errs
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package commands
|
|||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
|
@ -262,7 +261,7 @@ func (ctx *CmdCtx) CryptoGenerateRunE(cmd *cobra.Command, args []string) (err er
|
|||
privateKey any
|
||||
)
|
||||
|
||||
if privateKey, err = cryptoGenPrivateKeyFromCmd(cmd); err != nil {
|
||||
if privateKey, err = ctx.cryptoGenPrivateKeyFromCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -279,7 +278,7 @@ func (ctx *CmdCtx) CryptoCertificateRequestRunE(cmd *cobra.Command, _ []string)
|
|||
privateKey any
|
||||
)
|
||||
|
||||
if privateKey, err = cryptoGenPrivateKeyFromCmd(cmd); err != nil {
|
||||
if privateKey, err = ctx.cryptoGenPrivateKeyFromCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -326,7 +325,7 @@ func (ctx *CmdCtx) CryptoCertificateRequestRunE(cmd *cobra.Command, _ []string)
|
|||
|
||||
b.Reset()
|
||||
|
||||
if csr, err = x509.CreateCertificateRequest(rand.Reader, template, privateKey); err != nil {
|
||||
if csr, err = x509.CreateCertificateRequest(ctx.providers.Random, template, privateKey); err != nil {
|
||||
return fmt.Errorf("failed to create certificate request: %w", err)
|
||||
}
|
||||
|
||||
|
@ -366,7 +365,7 @@ func (ctx *CmdCtx) CryptoCertificateGenerateRunE(cmd *cobra.Command, _ []string,
|
|||
signatureKey = caPrivateKey
|
||||
}
|
||||
|
||||
if template, err = cryptoGetCertificateFromCmd(cmd); err != nil {
|
||||
if template, err = ctx.cryptoGetCertificateFromCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -423,7 +422,7 @@ func (ctx *CmdCtx) CryptoCertificateGenerateRunE(cmd *cobra.Command, _ []string,
|
|||
|
||||
b.Reset()
|
||||
|
||||
if certificate, err = x509.CreateCertificate(rand.Reader, template, parent, publicKey, signatureKey); err != nil {
|
||||
if certificate, err = x509.CreateCertificate(ctx.providers.Random, template, parent, publicKey, signatureKey); err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
|
@ -130,7 +129,7 @@ func cryptoGetWritePathsFromCmd(cmd *cobra.Command) (privateKey, publicKey strin
|
|||
return filepath.Join(dir, private), filepath.Join(dir, public), nil
|
||||
}
|
||||
|
||||
func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) {
|
||||
func (ctx *CmdCtx) cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) {
|
||||
switch cmd.Parent().Use {
|
||||
case cmdUseRSA:
|
||||
var (
|
||||
|
@ -141,7 +140,7 @@ func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if privateKey, err = rsa.GenerateKey(rand.Reader, bits); err != nil {
|
||||
if privateKey, err = rsa.GenerateKey(ctx.providers.Random, bits); err != nil {
|
||||
return nil, fmt.Errorf("generating RSA private key resulted in an error: %w", err)
|
||||
}
|
||||
case cmdUseECDSA:
|
||||
|
@ -158,11 +157,11 @@ func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error)
|
|||
return nil, fmt.Errorf("invalid curve '%s' was specified: curve must be P224, P256, P384, or P521", curveStr)
|
||||
}
|
||||
|
||||
if privateKey, err = ecdsa.GenerateKey(curve, rand.Reader); err != nil {
|
||||
if privateKey, err = ecdsa.GenerateKey(curve, ctx.providers.Random); err != nil {
|
||||
return nil, fmt.Errorf("generating ECDSA private key resulted in an error: %w", err)
|
||||
}
|
||||
case cmdUseEd25519:
|
||||
if _, privateKey, err = ed25519.GenerateKey(rand.Reader); err != nil {
|
||||
if _, privateKey, err = ed25519.GenerateKey(ctx.providers.Random); err != nil {
|
||||
return nil, fmt.Errorf("generating Ed25519 private key resulted in an error: %w", err)
|
||||
}
|
||||
}
|
||||
|
@ -336,7 +335,7 @@ func cryptoGetSubjectFromCmd(cmd *cobra.Command) (subject *pkix.Name, err error)
|
|||
}, nil
|
||||
}
|
||||
|
||||
func cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certificate, err error) {
|
||||
func (ctx *CmdCtx) cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certificate, err error) {
|
||||
var (
|
||||
ca bool
|
||||
subject *pkix.Name
|
||||
|
@ -378,7 +377,7 @@ func cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certific
|
|||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
|
||||
if serialNumber, err = rand.Int(rand.Reader, serialNumberLimit); err != nil {
|
||||
if serialNumber, err = ctx.providers.Random.IntErr(serialNumberLimit); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/validator"
|
||||
"github.com/authelia/authelia/v4/internal/model"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
"github.com/authelia/authelia/v4/internal/totp"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
|
@ -983,7 +984,8 @@ func (ctx *CmdCtx) StorageUserTOTPExportPNGRunE(cmd *cobra.Command, _ []string)
|
|||
}
|
||||
|
||||
if dir == "" {
|
||||
dir = utils.RandomString(8, utils.CharSetAlphaNumeric)
|
||||
rand := &random.Cryptographical{}
|
||||
dir = rand.StringCustom(8, random.CharSetAlphaNumeric)
|
||||
}
|
||||
|
||||
if _, err = os.Stat(dir); !os.IsNotExist(err) {
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
"golang.org/x/term"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
)
|
||||
|
||||
func recoverErr(i any) error {
|
||||
|
@ -77,29 +77,29 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar
|
|||
|
||||
switch c {
|
||||
case "ascii":
|
||||
charset = utils.CharSetASCII
|
||||
charset = random.CharSetASCII
|
||||
case "alphanumeric":
|
||||
charset = utils.CharSetAlphaNumeric
|
||||
charset = random.CharSetAlphaNumeric
|
||||
case "alphanumeric-lower":
|
||||
charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric
|
||||
charset = random.CharSetAlphabeticLower + random.CharSetNumeric
|
||||
case "alphanumeric-upper":
|
||||
charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric
|
||||
charset = random.CharSetAlphabeticUpper + random.CharSetNumeric
|
||||
case "alphabetic":
|
||||
charset = utils.CharSetAlphabetic
|
||||
charset = random.CharSetAlphabetic
|
||||
case "alphabetic-lower":
|
||||
charset = utils.CharSetAlphabeticLower
|
||||
charset = random.CharSetAlphabeticLower
|
||||
case "alphabetic-upper":
|
||||
charset = utils.CharSetAlphabeticUpper
|
||||
charset = random.CharSetAlphabeticUpper
|
||||
case "numeric-hex":
|
||||
charset = utils.CharSetNumericHex
|
||||
charset = random.CharSetNumericHex
|
||||
case "numeric":
|
||||
charset = utils.CharSetNumeric
|
||||
charset = random.CharSetNumeric
|
||||
case "rfc3986":
|
||||
charset = utils.CharSetRFC3986Unreserved
|
||||
charset = random.CharSetRFC3986Unreserved
|
||||
case "rfc3986-lower":
|
||||
charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved
|
||||
charset = random.CharSetAlphabeticLower + random.CharSetNumeric + random.CharSetSymbolicRFC3986Unreserved
|
||||
case "rfc3986-upper":
|
||||
charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved
|
||||
charset = random.CharSetAlphabeticUpper + random.CharSetNumeric + random.CharSetSymbolicRFC3986Unreserved
|
||||
default:
|
||||
return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c)
|
||||
}
|
||||
|
@ -109,7 +109,9 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar
|
|||
}
|
||||
}
|
||||
|
||||
return utils.RandomString(n, charset), nil
|
||||
rand := &random.Cryptographical{}
|
||||
|
||||
return rand.StringCustom(n, charset), nil
|
||||
}
|
||||
|
||||
func termReadConfirmation(flags *pflag.FlagSet, name, prompt, confirmation string) (confirmed bool, err error) {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/mocks"
|
||||
"github.com/authelia/authelia/v4/internal/model"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
|
@ -128,6 +129,7 @@ func TestShouldCallNextWithAutheliaCtx(t *testing.T) {
|
|||
providers := middlewares.Providers{
|
||||
UserProvider: userProvider,
|
||||
SessionProvider: sessionProvider,
|
||||
Random: random.NewMathematical(),
|
||||
}
|
||||
nextCalled := false
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/mocks"
|
||||
|
@ -73,8 +74,8 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
|
|||
defer mock.Close()
|
||||
|
||||
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
||||
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
||||
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
||||
mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedProto, "http")
|
||||
mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedHost, "host")
|
||||
|
||||
mock.StorageMock.EXPECT().
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
|
@ -95,8 +96,8 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
|
|||
mock := mocks.NewMockAutheliaCtx(t)
|
||||
|
||||
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
||||
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
||||
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
||||
mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedProto, "http")
|
||||
mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedHost, "host")
|
||||
|
||||
mock.StorageMock.EXPECT().
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math"
|
||||
"math/big"
|
||||
"sync"
|
||||
|
@ -62,7 +61,7 @@ func movingAverageIteration(value time.Duration, history int, successful bool, c
|
|||
}
|
||||
|
||||
func calculateActualDelay(ctx *AutheliaCtx, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
|
||||
randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs))
|
||||
randomDelayMs, err := ctx.Providers.Random.IntErr(big.NewInt(maxRandomMs))
|
||||
if err != nil {
|
||||
return float64(maxRandomMs)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
)
|
||||
|
||||
func TestTimingAttackDelayAverages(t *testing.T) {
|
||||
|
@ -45,7 +46,12 @@ func TestTimingAttackDelayCalculations(t *testing.T) {
|
|||
avgExecDurationMs := 1000.0
|
||||
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
|
||||
|
||||
ctx := &AutheliaCtx{Logger: logging.Logger().WithFields(logrus.Fields{})}
|
||||
ctx := &AutheliaCtx{
|
||||
Logger: logging.Logger().WithFields(logrus.Fields{}),
|
||||
Providers: Providers{
|
||||
Random: &random.Cryptographical{},
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(ctx, execDuration, avgExecDurationMs, 250, 85, false)
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/notification"
|
||||
"github.com/authelia/authelia/v4/internal/ntp"
|
||||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/regulation"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
|
@ -44,6 +45,7 @@ type Providers struct {
|
|||
Templates *templates.Provider
|
||||
TOTP totp.Provider
|
||||
PasswordPolicy PasswordPolicyProvider
|
||||
Random random.Provider
|
||||
}
|
||||
|
||||
// RequestHandler represents an Authelia request handler.
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/authorization"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/regulation"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
"github.com/authelia/authelia/v4/internal/templates"
|
||||
|
@ -34,6 +35,7 @@ type MockAutheliaCtx struct {
|
|||
StorageMock *MockStorage
|
||||
NotifierMock *MockNotifier
|
||||
TOTPMock *MockTOTP
|
||||
RandomMock *MockRandom
|
||||
|
||||
UserSession *session.UserSession
|
||||
|
||||
|
@ -98,6 +100,10 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx {
|
|||
mockAuthelia.TOTPMock = NewMockTOTP(mockAuthelia.Ctrl)
|
||||
providers.TOTP = mockAuthelia.TOTPMock
|
||||
|
||||
mockAuthelia.RandomMock = NewMockRandom(mockAuthelia.Ctrl)
|
||||
|
||||
providers.Random = random.NewMathematical()
|
||||
|
||||
var err error
|
||||
|
||||
if providers.Templates, err = templates.New(templates.Config{}); err != nil {
|
||||
|
|
|
@ -8,3 +8,4 @@ package mocks
|
|||
//go:generate mockgen -package mocks -destination totp.go -mock_names Provider=MockTOTP github.com/authelia/authelia/v4/internal/totp Provider
|
||||
//go:generate mockgen -package mocks -destination storage.go -mock_names Provider=MockStorage github.com/authelia/authelia/v4/internal/storage Provider
|
||||
//go:generate mockgen -package mocks -destination duo_api.go -mock_names API=MockAPI github.com/authelia/authelia/v4/internal/duo API
|
||||
//go:generate mockgen -package mocks -destination random.go -mock_names Provider=MockRandom github.com/authelia/authelia/v4/internal/random Provider
|
||||
|
|
|
@ -0,0 +1,195 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/authelia/authelia/v4/internal/random (interfaces: Provider)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
big "math/big"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockRandom is a mock of Provider interface.
|
||||
type MockRandom struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRandomMockRecorder
|
||||
}
|
||||
|
||||
// MockRandomMockRecorder is the mock recorder for MockRandom.
|
||||
type MockRandomMockRecorder struct {
|
||||
mock *MockRandom
|
||||
}
|
||||
|
||||
// NewMockRandom creates a new mock instance.
|
||||
func NewMockRandom(ctrl *gomock.Controller) *MockRandom {
|
||||
mock := &MockRandom{ctrl: ctrl}
|
||||
mock.recorder = &MockRandomMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRandom) EXPECT() *MockRandomMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Bytes mocks base method.
|
||||
func (m *MockRandom) Bytes() []byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Bytes")
|
||||
ret0, _ := ret[0].([]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Bytes indicates an expected call of Bytes.
|
||||
func (mr *MockRandomMockRecorder) Bytes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockRandom)(nil).Bytes))
|
||||
}
|
||||
|
||||
// BytesCustom mocks base method.
|
||||
func (m *MockRandom) BytesCustom(arg0 int, arg1 []byte) []byte {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BytesCustom", arg0, arg1)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BytesCustom indicates an expected call of BytesCustom.
|
||||
func (mr *MockRandomMockRecorder) BytesCustom(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesCustom", reflect.TypeOf((*MockRandom)(nil).BytesCustom), arg0, arg1)
|
||||
}
|
||||
|
||||
// BytesCustomErr mocks base method.
|
||||
func (m *MockRandom) BytesCustomErr(arg0 int, arg1 []byte) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BytesCustomErr", arg0, arg1)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BytesCustomErr indicates an expected call of BytesCustomErr.
|
||||
func (mr *MockRandomMockRecorder) BytesCustomErr(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesCustomErr", reflect.TypeOf((*MockRandom)(nil).BytesCustomErr), arg0, arg1)
|
||||
}
|
||||
|
||||
// BytesErr mocks base method.
|
||||
func (m *MockRandom) BytesErr() ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BytesErr")
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BytesErr indicates an expected call of BytesErr.
|
||||
func (mr *MockRandomMockRecorder) BytesErr() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesErr", reflect.TypeOf((*MockRandom)(nil).BytesErr))
|
||||
}
|
||||
|
||||
// Int mocks base method.
|
||||
func (m *MockRandom) Int(arg0 *big.Int) *big.Int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Int", arg0)
|
||||
ret0, _ := ret[0].(*big.Int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Int indicates an expected call of Int.
|
||||
func (mr *MockRandomMockRecorder) Int(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int", reflect.TypeOf((*MockRandom)(nil).Int), arg0)
|
||||
}
|
||||
|
||||
// IntErr mocks base method.
|
||||
func (m *MockRandom) IntErr(arg0 *big.Int) (*big.Int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IntErr", arg0)
|
||||
ret0, _ := ret[0].(*big.Int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IntErr indicates an expected call of IntErr.
|
||||
func (mr *MockRandomMockRecorder) IntErr(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntErr", reflect.TypeOf((*MockRandom)(nil).IntErr), arg0)
|
||||
}
|
||||
|
||||
// Integer mocks base method.
|
||||
func (m *MockRandom) Integer(arg0 int) int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Integer", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Integer indicates an expected call of Integer.
|
||||
func (mr *MockRandomMockRecorder) Integer(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Integer", reflect.TypeOf((*MockRandom)(nil).Integer), arg0)
|
||||
}
|
||||
|
||||
// IntegerErr mocks base method.
|
||||
func (m *MockRandom) IntegerErr(arg0 int) (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IntegerErr", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IntegerErr indicates an expected call of IntegerErr.
|
||||
func (mr *MockRandomMockRecorder) IntegerErr(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntegerErr", reflect.TypeOf((*MockRandom)(nil).IntegerErr), arg0)
|
||||
}
|
||||
|
||||
// Read mocks base method.
|
||||
func (m *MockRandom) Read(arg0 []byte) (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read.
|
||||
func (mr *MockRandomMockRecorder) Read(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockRandom)(nil).Read), arg0)
|
||||
}
|
||||
|
||||
// StringCustom mocks base method.
|
||||
func (m *MockRandom) StringCustom(arg0 int, arg1 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StringCustom", arg0, arg1)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StringCustom indicates an expected call of StringCustom.
|
||||
func (mr *MockRandomMockRecorder) StringCustom(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringCustom", reflect.TypeOf((*MockRandom)(nil).StringCustom), arg0, arg1)
|
||||
}
|
||||
|
||||
// StringCustomErr mocks base method.
|
||||
func (m *MockRandom) StringCustomErr(arg0 int, arg1 string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StringCustomErr", arg0, arg1)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// StringCustomErr indicates an expected call of StringCustomErr.
|
||||
func (mr *MockRandomMockRecorder) StringCustomErr(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringCustomErr", reflect.TypeOf((*MockRandom)(nil).StringCustomErr), arg0, arg1)
|
||||
}
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/templates"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
@ -64,6 +65,7 @@ func NewSMTPNotifier(config *schema.SMTPNotifierConfiguration, certPool *x509.Ce
|
|||
return &SMTPNotifier{
|
||||
config: config,
|
||||
domain: domain,
|
||||
random: &random.Cryptographical{},
|
||||
tls: utils.NewTLSConfig(config.TLS, certPool),
|
||||
log: logging.Logger(),
|
||||
opts: opts,
|
||||
|
@ -74,6 +76,7 @@ func NewSMTPNotifier(config *schema.SMTPNotifierConfiguration, certPool *x509.Ce
|
|||
type SMTPNotifier struct {
|
||||
config *schema.SMTPNotifierConfiguration
|
||||
domain string
|
||||
random random.Provider
|
||||
tls *tls.Config
|
||||
log *logrus.Logger
|
||||
opts []gomail.Option
|
||||
|
@ -104,10 +107,10 @@ func (n *SMTPNotifier) StartupCheck() (err error) {
|
|||
func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject string, et *templates.EmailTemplate, data any) (err error) {
|
||||
msg := gomail.NewMsg(
|
||||
gomail.WithMIMEVersion(gomail.Mime10),
|
||||
gomail.WithBoundary(utils.RandomString(30, utils.CharSetAlphaNumeric)),
|
||||
gomail.WithBoundary(n.random.StringCustom(30, random.CharSetAlphaNumeric)),
|
||||
)
|
||||
|
||||
setMessageID(msg, n.domain)
|
||||
n.setMessageID(msg, n.domain)
|
||||
|
||||
if err = msg.From(n.config.Sender.String()); err != nil {
|
||||
return fmt.Errorf("notifier: smtp: failed to set from address: %w", err)
|
||||
|
@ -161,10 +164,10 @@ func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject
|
|||
return nil
|
||||
}
|
||||
|
||||
func setMessageID(msg *gomail.Msg, domain string) {
|
||||
rn, _ := utils.RandomInt(100000000)
|
||||
rm, _ := utils.RandomInt(10000)
|
||||
rs := utils.RandomString(17, utils.CharSetAlphaNumeric)
|
||||
func (n *SMTPNotifier) setMessageID(msg *gomail.Msg, domain string) {
|
||||
rn := n.random.Integer(100000000)
|
||||
rm := n.random.Integer(10000)
|
||||
rs := n.random.StringCustom(17, random.CharSetAlphaNumeric)
|
||||
pid := os.Getpid() + rm
|
||||
|
||||
msg.SetMessageIDWithValue(fmt.Sprintf("%d.%d%d.%s@%s", pid, rn, rm, rs, domain))
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
package random
|
||||
|
||||
const (
|
||||
// DefaultN is the default value of n.
|
||||
DefaultN = 72
|
||||
)
|
||||
|
||||
const (
|
||||
// CharSetAlphabeticLower are literally just valid alphabetic lowercase printable ASCII chars.
|
||||
CharSetAlphabeticLower = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
// CharSetAlphabeticUpper are literally just valid alphabetic uppercase printable ASCII chars.
|
||||
CharSetAlphabeticUpper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
// CharSetAlphabetic are literally just valid alphabetic printable ASCII chars.
|
||||
CharSetAlphabetic = CharSetAlphabeticLower + CharSetAlphabeticUpper
|
||||
|
||||
// CharSetNumeric are literally just valid numeric chars.
|
||||
CharSetNumeric = "0123456789"
|
||||
|
||||
// CharSetNumericHex are literally just valid hexadecimal printable ASCII chars.
|
||||
CharSetNumericHex = CharSetNumeric + "ABCDEF"
|
||||
|
||||
// CharSetSymbolic are literally just valid symbolic printable ASCII chars.
|
||||
CharSetSymbolic = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
||||
|
||||
// CharSetSymbolicRFC3986Unreserved are RFC3986 unreserved symbol characters.
|
||||
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
|
||||
CharSetSymbolicRFC3986Unreserved = "-._~"
|
||||
|
||||
// CharSetAlphaNumeric are literally just valid alphanumeric printable ASCII chars.
|
||||
CharSetAlphaNumeric = CharSetAlphabetic + CharSetNumeric
|
||||
|
||||
// CharSetASCII are literally just valid printable ASCII chars.
|
||||
CharSetASCII = CharSetAlphabetic + CharSetNumeric + CharSetSymbolic
|
||||
|
||||
// CharSetRFC3986Unreserved are RFC3986 unreserved characters.
|
||||
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
|
||||
CharSetRFC3986Unreserved = CharSetAlphabetic + CharSetNumeric + CharSetSymbolicRFC3986Unreserved
|
||||
|
||||
// CharSetUnambiguousUpper are a set of unambiguous uppercase characters.
|
||||
CharSetUnambiguousUpper = "ABCDEFGHJKLMNOPQRTUVWYXZ2346789"
|
||||
)
|
|
@ -0,0 +1,136 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// Cryptographical is the production random.Provider which uses crypto/rand.
|
||||
type Cryptographical struct{}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (r *Cryptographical) Read(p []byte) (n int, err error) {
|
||||
return io.ReadFull(rand.Reader, p)
|
||||
}
|
||||
|
||||
// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values
|
||||
// (including unreadable byte values). If an error is returned from the random read this function returns it.
|
||||
func (r *Cryptographical) BytesErr() (data []byte, err error) {
|
||||
data = make([]byte, DefaultN)
|
||||
|
||||
_, err = rand.Read(data)
|
||||
|
||||
return data, err
|
||||
}
|
||||
|
||||
// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values
|
||||
// (including unreadable byte values). If an error is returned from the random read this function ignores it.
|
||||
func (r *Cryptographical) Bytes() (data []byte) {
|
||||
data, _ = r.BytesErr()
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided
|
||||
// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
|
||||
// returns it.
|
||||
func (r *Cryptographical) BytesCustomErr(n int, charset []byte) (data []byte, err error) {
|
||||
if n < 1 {
|
||||
n = DefaultN
|
||||
}
|
||||
|
||||
data = make([]byte, n)
|
||||
|
||||
if _, err = rand.Read(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := len(charset)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
data[i] = charset[data[i]%byte(t)]
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string.
|
||||
func (r *Cryptographical) StringCustomErr(n int, characters string) (data string, err error) {
|
||||
var d []byte
|
||||
|
||||
if d, err = r.BytesCustomErr(n, []byte(characters)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(d), nil
|
||||
}
|
||||
|
||||
// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided values.
|
||||
// If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
|
||||
// ignores it.
|
||||
func (r *Cryptographical) BytesCustom(n int, charset []byte) (data []byte) {
|
||||
data, _ = r.BytesCustomErr(n, charset)
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// StringCustom is an overload of BytesCustom which takes a characters string and returns a string.
|
||||
func (r *Cryptographical) StringCustom(n int, characters string) (data string) {
|
||||
return string(r.BytesCustom(n, []byte(characters)))
|
||||
}
|
||||
|
||||
// IntErr returns a random *big.Int error combination with a maximum of max.
|
||||
func (r *Cryptographical) IntErr(max *big.Int) (value *big.Int, err error) {
|
||||
if max == nil {
|
||||
return nil, fmt.Errorf("max is required")
|
||||
}
|
||||
|
||||
if max.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("max must be 1 or more")
|
||||
}
|
||||
|
||||
return rand.Int(rand.Reader, max)
|
||||
}
|
||||
|
||||
// Int returns a random *big.Int with a maximum of max.
|
||||
func (r *Cryptographical) Int(max *big.Int) (value *big.Int) {
|
||||
var err error
|
||||
|
||||
if value, err = r.IntErr(max); err != nil {
|
||||
return big.NewInt(-1)
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// IntegerErr returns a random int error combination with a maximum of n.
|
||||
func (r *Cryptographical) IntegerErr(n int) (value int, err error) {
|
||||
if n <= 0 {
|
||||
return 0, fmt.Errorf("n must be more than 0")
|
||||
}
|
||||
|
||||
max := big.NewInt(int64(n))
|
||||
|
||||
var result *big.Int
|
||||
|
||||
if result, err = r.IntErr(max); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
value = int(result.Int64())
|
||||
|
||||
if value < 0 {
|
||||
return 0, fmt.Errorf("generated number is too big for int")
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Integer returns a random int with a maximum of n.
|
||||
func (r *Cryptographical) Integer(n int) (value int) {
|
||||
value, _ = r.IntegerErr(n)
|
||||
|
||||
return value
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewMathematical runs rand.Seed with the current time and returns a random.Provider, specifically *random.Mathematical.
|
||||
func NewMathematical() *Mathematical {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
return &Mathematical{}
|
||||
}
|
||||
|
||||
// Mathematical is the random.Provider which uses math/rand and is COMPLETELY UNSAFE FOR PRODUCTION IN MOST SITUATIONS.
|
||||
// Use random.Cryptographical instead.
|
||||
type Mathematical struct{}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (r *Mathematical) Read(p []byte) (n int, err error) {
|
||||
return rand.Read(p) //nolint:gosec
|
||||
}
|
||||
|
||||
// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values
|
||||
// (including unreadable byte values). If an error is returned from the random read this function returns it.
|
||||
func (r *Mathematical) BytesErr() (data []byte, err error) {
|
||||
data = make([]byte, DefaultN)
|
||||
|
||||
if _, err = rand.Read(data); err != nil { //nolint:gosec
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values
|
||||
// (including unreadable byte values). If an error is returned from the random read this function ignores it.
|
||||
func (r *Mathematical) Bytes() (data []byte) {
|
||||
data, _ = r.BytesErr()
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided
|
||||
// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
|
||||
// returns it.
|
||||
func (r *Mathematical) BytesCustomErr(n int, charset []byte) (data []byte, err error) {
|
||||
if n < 1 {
|
||||
n = DefaultN
|
||||
}
|
||||
|
||||
data = make([]byte, n)
|
||||
|
||||
if _, err = rand.Read(data); err != nil { //nolint:gosec
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := len(charset)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
data[i] = charset[data[i]%byte(t)]
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string.
|
||||
func (r *Mathematical) StringCustomErr(n int, characters string) (data string, err error) {
|
||||
var d []byte
|
||||
|
||||
if d, err = r.BytesCustomErr(n, []byte(characters)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(d), nil
|
||||
}
|
||||
|
||||
// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided values.
|
||||
// If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
|
||||
// ignores it.
|
||||
func (r *Mathematical) BytesCustom(n int, charset []byte) (data []byte) {
|
||||
data, _ = r.BytesCustomErr(n, charset)
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// StringCustom is an overload of BytesCustom which takes a characters string and returns a string.
|
||||
func (r *Mathematical) StringCustom(n int, characters string) (data string) {
|
||||
return string(r.BytesCustom(n, []byte(characters)))
|
||||
}
|
||||
|
||||
// IntErr returns a random *big.Int error combination with a maximum of max.
|
||||
func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) {
|
||||
if max == nil {
|
||||
return nil, fmt.Errorf("max is required")
|
||||
}
|
||||
|
||||
if max.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("max must be 1 or more")
|
||||
}
|
||||
|
||||
return big.NewInt(int64(rand.Intn(max.Sign()))), nil //nolint:gosec
|
||||
}
|
||||
|
||||
// Int returns a random *big.Int with a maximum of max.
|
||||
func (r *Mathematical) Int(max *big.Int) (value *big.Int) {
|
||||
var err error
|
||||
|
||||
if value, err = r.IntErr(max); err != nil {
|
||||
return big.NewInt(-1)
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// IntegerErr returns a random int error combination with a maximum of n.
|
||||
func (r *Mathematical) IntegerErr(n int) (output int, err error) {
|
||||
return r.Integer(n), nil
|
||||
}
|
||||
|
||||
// Integer returns a random int with a maximum of n.
|
||||
func (r *Mathematical) Integer(n int) int {
|
||||
return rand.Intn(n) //nolint:gosec
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"io"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// Provider of random functions and functionality.
|
||||
type Provider interface {
|
||||
io.Reader
|
||||
|
||||
// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values
|
||||
// (including unreadable byte values). If an error is returned from the random read this function returns it.
|
||||
BytesErr() (data []byte, err error)
|
||||
|
||||
// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values
|
||||
// (including unreadable byte values). If an error is returned from the random read this function ignores it.
|
||||
Bytes() (data []byte)
|
||||
|
||||
// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided
|
||||
// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function
|
||||
// returns it.
|
||||
BytesCustomErr(n int, charset []byte) (data []byte, err error)
|
||||
|
||||
// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string.
|
||||
StringCustomErr(n int, characters string) (data string, err error)
|
||||
|
||||
// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided
|
||||
// values. If n is less than 1 then DefaultN is used instead.
|
||||
BytesCustom(n int, charset []byte) (data []byte)
|
||||
|
||||
// StringCustom is an overload of GenerateCustom which takes a characters string and returns a string.
|
||||
StringCustom(n int, characters string) (data string)
|
||||
|
||||
// IntErr returns a random *big.Int error combination with a maximum of max.
|
||||
IntErr(max *big.Int) (value *big.Int, err error)
|
||||
|
||||
// Int returns a random *big.Int with a maximum of max.
|
||||
Int(max *big.Int) (value *big.Int)
|
||||
|
||||
// IntegerErr returns a random int error combination with a maximum of n.
|
||||
IntegerErr(n int) (value int, err error)
|
||||
|
||||
// Integer returns a random integer with a maximum of n.
|
||||
Integer(n int) (value int)
|
||||
}
|
|
@ -15,8 +15,8 @@ import (
|
|||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
"github.com/authelia/authelia/v4/internal/templates"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// ServeTemplatedFile serves a templated version of a specified file,
|
||||
|
@ -46,7 +46,7 @@ func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middle
|
|||
ctx.SetContentTypeTextPlain()
|
||||
}
|
||||
|
||||
nonce := utils.RandomString(32, utils.CharSetAlphaNumeric)
|
||||
nonce := ctx.Providers.Random.StringCustom(32, random.CharSetAlphaNumeric)
|
||||
|
||||
switch {
|
||||
case ctx.Configuration.Server.Headers.CSPTemplate != "":
|
||||
|
@ -78,7 +78,7 @@ func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) mid
|
|||
if spec {
|
||||
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, tmplCSPSwagger)
|
||||
} else {
|
||||
nonce = utils.RandomString(32, utils.CharSetAlphaNumeric)
|
||||
nonce = ctx.Providers.Random.StringCustom(32, random.CharSetAlphaNumeric)
|
||||
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPSwaggerNonce, nonce, nonce))
|
||||
}
|
||||
|
||||
|
|
|
@ -97,40 +97,6 @@ const (
|
|||
timeUnixEpochAsMicrosoftNTEpoch uint64 = 116444736000000000
|
||||
)
|
||||
|
||||
const (
|
||||
// CharSetAlphabeticLower are literally just valid alphabetic lowercase printable ASCII chars.
|
||||
CharSetAlphabeticLower = "abcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
// CharSetAlphabeticUpper are literally just valid alphabetic uppercase printable ASCII chars.
|
||||
CharSetAlphabeticUpper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
// CharSetAlphabetic are literally just valid alphabetic printable ASCII chars.
|
||||
CharSetAlphabetic = CharSetAlphabeticLower + CharSetAlphabeticUpper
|
||||
|
||||
// CharSetNumeric are literally just valid numeric chars.
|
||||
CharSetNumeric = "0123456789"
|
||||
|
||||
// CharSetNumericHex are literally just valid hexadecimal printable ASCII chars.
|
||||
CharSetNumericHex = CharSetNumeric + "ABCDEF"
|
||||
|
||||
// CharSetSymbolic are literally just valid symbolic printable ASCII chars.
|
||||
CharSetSymbolic = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
||||
|
||||
// CharSetSymbolicRFC3986Unreserved are RFC3986 unreserved symbol characters.
|
||||
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
|
||||
CharSetSymbolicRFC3986Unreserved = "-._~"
|
||||
|
||||
// CharSetAlphaNumeric are literally just valid alphanumeric printable ASCII chars.
|
||||
CharSetAlphaNumeric = CharSetAlphabetic + CharSetNumeric
|
||||
|
||||
// CharSetASCII are literally just valid printable ASCII chars.
|
||||
CharSetASCII = CharSetAlphabetic + CharSetNumeric + CharSetSymbolic
|
||||
|
||||
// CharSetRFC3986Unreserved are RFC3986 unreserved characters.
|
||||
// See https://www.rfc-editor.org/rfc/rfc3986#section-2.3.
|
||||
CharSetRFC3986Unreserved = CharSetAlphabetic + CharSetNumeric + CharSetSymbolicRFC3986Unreserved
|
||||
)
|
||||
|
||||
var htmlEscaper = strings.NewReplacer(
|
||||
"&", "&",
|
||||
"<", "<",
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
)
|
||||
|
||||
const (
|
||||
testStringInput = "abcdefghijkl"
|
||||
)
|
||||
|
||||
var r = &random.Cryptographical{}
|
||||
|
|
|
@ -573,50 +573,3 @@ loop:
|
|||
|
||||
return extKeyUsage
|
||||
}
|
||||
|
||||
// RandomString returns a random string with a given length with values from the provided characters. When crypto is set
|
||||
// to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true
|
||||
// excluding when the task is time sensitive and would not benefit from extra randomness.
|
||||
func RandomString(n int, characters string) (randomString string) {
|
||||
return string(RandomBytes(n, characters))
|
||||
}
|
||||
|
||||
// RandomBytes returns a random []byte with a given length with values from the provided characters. When crypto is set
|
||||
// to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true
|
||||
// excluding when the task is time sensitive and would not benefit from extra randomness.
|
||||
func RandomBytes(n int, characters string) (bytes []byte) {
|
||||
bytes = make([]byte, n)
|
||||
|
||||
_, _ = rand.Read(bytes)
|
||||
|
||||
for i, b := range bytes {
|
||||
bytes[i] = characters[b%byte(len(characters))]
|
||||
}
|
||||
|
||||
return bytes
|
||||
}
|
||||
|
||||
func RandomInt(n int) (int, error) {
|
||||
if n <= 0 {
|
||||
return 0, fmt.Errorf("n must be more than 0")
|
||||
}
|
||||
|
||||
max := big.NewInt(int64(n))
|
||||
|
||||
if !max.IsUint64() {
|
||||
return 0, fmt.Errorf("generated max is negative")
|
||||
}
|
||||
|
||||
value, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
output := int(value.Int64())
|
||||
|
||||
if output < 0 {
|
||||
return 0, fmt.Errorf("generated number is too big for int")
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/random"
|
||||
)
|
||||
|
||||
func TestShouldHashString(t *testing.T) {
|
||||
|
@ -22,7 +24,7 @@ func TestShouldHashString(t *testing.T) {
|
|||
assert.Equal(t, "ae448ac86c4e8e4dec645729708ef41873ae79c6dff84eff73360989487f08e5", anotherSum)
|
||||
assert.NotEqual(t, sum, anotherSum)
|
||||
|
||||
randomInput := RandomString(40, CharSetAlphaNumeric)
|
||||
randomInput := r.StringCustom(40, random.CharSetAlphaNumeric)
|
||||
randomSum := HashSHA256FromString(randomInput)
|
||||
|
||||
assert.NotEqual(t, randomSum, sum)
|
||||
|
@ -38,7 +40,7 @@ func TestShouldHashPath(t *testing.T) {
|
|||
err = os.WriteFile(filepath.Join(dir, "anotherfile"), []byte("another\n"), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(RandomString(40, CharSetAlphaNumeric)+"\n"), 0600)
|
||||
err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(r.StringCustom(40, random.CharSetAlphaNumeric)+"\n"), 0600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sum, err := HashSHA256FromPath(filepath.Join(dir, "myfile"))
|
||||
|
|
|
@ -53,13 +53,6 @@ func TestStringJoinDelimitedEscaped(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestShouldNotGenerateSameRandomString(t *testing.T) {
|
||||
randomStringOne := RandomString(10, CharSetAlphaNumeric)
|
||||
randomStringTwo := RandomString(10, CharSetAlphaNumeric)
|
||||
|
||||
assert.NotEqual(t, randomStringOne, randomStringTwo)
|
||||
}
|
||||
|
||||
func TestShouldDetectAlphaNumericString(t *testing.T) {
|
||||
assert.True(t, IsStringAlphaNumeric("abc"))
|
||||
assert.True(t, IsStringAlphaNumeric("abc123"))
|
||||
|
|
Loading…
Reference in New Issue