diff --git a/go.mod b/go.mod index fabdf74c0..390edcd5b 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,6 @@ require ( github.com/trustelem/zxcvbn v1.0.1 github.com/valyala/fasthttp v1.43.0 github.com/wneessen/go-mail v0.3.7 - golang.org/x/net v0.5.0 golang.org/x/sync v0.1.0 golang.org/x/term v0.4.0 golang.org/x/text v0.6.0 @@ -110,6 +109,7 @@ require ( github.com/ysmood/leakless v0.8.0 // indirect golang.org/x/crypto v0.1.0 // indirect golang.org/x/mod v0.6.0 // indirect + golang.org/x/net v0.5.0 // indirect golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b // indirect golang.org/x/sys v0.4.0 // indirect golang.org/x/tools v0.2.0 // indirect diff --git a/internal/commands/context.go b/internal/commands/context.go index 737ab9ce0..a0f2df2b2 100644 --- a/internal/commands/context.go +++ b/internal/commands/context.go @@ -1,6 +1,7 @@ package commands import ( + "context" "crypto/x509" "fmt" "os" @@ -8,7 +9,6 @@ import ( "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/pflag" - "golang.org/x/net/context" "golang.org/x/sync/errgroup" "github.com/authelia/authelia/v4/internal/authentication" @@ -22,6 +22,7 @@ import ( "github.com/authelia/authelia/v4/internal/notification" "github.com/authelia/authelia/v4/internal/ntp" "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/storage" @@ -43,7 +44,10 @@ func NewCmdCtx() *CmdCtx { cancel: cancel, group: group, log: logging.Logger(), - config: &schema.Configuration{}, + providers: middlewares.Providers{ + Random: &random.Cryptographical{}, + }, + config: &schema.Configuration{}, } } @@ -139,48 +143,43 @@ func (ctx *CmdCtx) LoadProviders() (warns, errs []error) { return warns, errs } - storage := getStorageProvider(ctx) + ctx.providers.StorageProvider = getStorageProvider(ctx) - providers := middlewares.Providers{ - Authorizer: authorization.NewAuthorizer(ctx.config), - NTP: ntp.NewProvider(&ctx.config.NTP), - PasswordPolicy: middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy), - Regulator: regulation.NewRegulator(ctx.config.Regulation, storage, utils.RealClock{}), - SessionProvider: session.NewProvider(ctx.config.Session, ctx.trusted), - StorageProvider: storage, - TOTP: totp.NewTimeBasedProvider(ctx.config.TOTP), - } + ctx.providers.Authorizer = authorization.NewAuthorizer(ctx.config) + ctx.providers.NTP = ntp.NewProvider(&ctx.config.NTP) + ctx.providers.PasswordPolicy = middlewares.NewPasswordPolicyProvider(ctx.config.PasswordPolicy) + ctx.providers.Regulator = regulation.NewRegulator(ctx.config.Regulation, ctx.providers.StorageProvider, utils.RealClock{}) + ctx.providers.SessionProvider = session.NewProvider(ctx.config.Session, ctx.trusted) + ctx.providers.TOTP = totp.NewTimeBasedProvider(ctx.config.TOTP) var err error switch { case ctx.config.AuthenticationBackend.File != nil: - providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File) + ctx.providers.UserProvider = authentication.NewFileUserProvider(ctx.config.AuthenticationBackend.File) case ctx.config.AuthenticationBackend.LDAP != nil: - providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted) + ctx.providers.UserProvider = authentication.NewLDAPUserProvider(ctx.config.AuthenticationBackend, ctx.trusted) } - if providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil { + if ctx.providers.Templates, err = templates.New(templates.Config{EmailTemplatesPath: ctx.config.Notifier.TemplatePath}); err != nil { errs = append(errs, err) } switch { case ctx.config.Notifier.SMTP != nil: - providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted) + ctx.providers.Notifier = notification.NewSMTPNotifier(ctx.config.Notifier.SMTP, ctx.trusted) case ctx.config.Notifier.FileSystem != nil: - providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem) + ctx.providers.Notifier = notification.NewFileNotifier(*ctx.config.Notifier.FileSystem) } - if providers.OpenIDConnect, err = oidc.NewOpenIDConnectProvider(ctx.config.IdentityProviders.OIDC, storage); err != nil { + if ctx.providers.OpenIDConnect, err = oidc.NewOpenIDConnectProvider(ctx.config.IdentityProviders.OIDC, ctx.providers.StorageProvider); err != nil { errs = append(errs, err) } if ctx.config.Telemetry.Metrics.Enabled { - providers.Metrics = metrics.NewPrometheus() + ctx.providers.Metrics = metrics.NewPrometheus() } - ctx.providers = providers - return warns, errs } diff --git a/internal/commands/crypto.go b/internal/commands/crypto.go index 6cad8f23f..74c8d4099 100644 --- a/internal/commands/crypto.go +++ b/internal/commands/crypto.go @@ -3,7 +3,6 @@ package commands import ( "crypto/ecdsa" "crypto/ed25519" - "crypto/rand" "crypto/rsa" "crypto/x509" "fmt" @@ -262,7 +261,7 @@ func (ctx *CmdCtx) CryptoGenerateRunE(cmd *cobra.Command, args []string) (err er privateKey any ) - if privateKey, err = cryptoGenPrivateKeyFromCmd(cmd); err != nil { + if privateKey, err = ctx.cryptoGenPrivateKeyFromCmd(cmd); err != nil { return err } @@ -279,7 +278,7 @@ func (ctx *CmdCtx) CryptoCertificateRequestRunE(cmd *cobra.Command, _ []string) privateKey any ) - if privateKey, err = cryptoGenPrivateKeyFromCmd(cmd); err != nil { + if privateKey, err = ctx.cryptoGenPrivateKeyFromCmd(cmd); err != nil { return err } @@ -326,7 +325,7 @@ func (ctx *CmdCtx) CryptoCertificateRequestRunE(cmd *cobra.Command, _ []string) b.Reset() - if csr, err = x509.CreateCertificateRequest(rand.Reader, template, privateKey); err != nil { + if csr, err = x509.CreateCertificateRequest(ctx.providers.Random, template, privateKey); err != nil { return fmt.Errorf("failed to create certificate request: %w", err) } @@ -366,7 +365,7 @@ func (ctx *CmdCtx) CryptoCertificateGenerateRunE(cmd *cobra.Command, _ []string, signatureKey = caPrivateKey } - if template, err = cryptoGetCertificateFromCmd(cmd); err != nil { + if template, err = ctx.cryptoGetCertificateFromCmd(cmd); err != nil { return err } @@ -423,7 +422,7 @@ func (ctx *CmdCtx) CryptoCertificateGenerateRunE(cmd *cobra.Command, _ []string, b.Reset() - if certificate, err = x509.CreateCertificate(rand.Reader, template, parent, publicKey, signatureKey); err != nil { + if certificate, err = x509.CreateCertificate(ctx.providers.Random, template, parent, publicKey, signatureKey); err != nil { return fmt.Errorf("failed to create certificate: %w", err) } diff --git a/internal/commands/crypto_helper.go b/internal/commands/crypto_helper.go index 1827970ed..17ee4060f 100644 --- a/internal/commands/crypto_helper.go +++ b/internal/commands/crypto_helper.go @@ -4,7 +4,6 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" - "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" @@ -130,7 +129,7 @@ func cryptoGetWritePathsFromCmd(cmd *cobra.Command) (privateKey, publicKey strin return filepath.Join(dir, private), filepath.Join(dir, public), nil } -func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) { +func (ctx *CmdCtx) cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) { switch cmd.Parent().Use { case cmdUseRSA: var ( @@ -141,7 +140,7 @@ func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) return nil, err } - if privateKey, err = rsa.GenerateKey(rand.Reader, bits); err != nil { + if privateKey, err = rsa.GenerateKey(ctx.providers.Random, bits); err != nil { return nil, fmt.Errorf("generating RSA private key resulted in an error: %w", err) } case cmdUseECDSA: @@ -158,11 +157,11 @@ func cryptoGenPrivateKeyFromCmd(cmd *cobra.Command) (privateKey any, err error) return nil, fmt.Errorf("invalid curve '%s' was specified: curve must be P224, P256, P384, or P521", curveStr) } - if privateKey, err = ecdsa.GenerateKey(curve, rand.Reader); err != nil { + if privateKey, err = ecdsa.GenerateKey(curve, ctx.providers.Random); err != nil { return nil, fmt.Errorf("generating ECDSA private key resulted in an error: %w", err) } case cmdUseEd25519: - if _, privateKey, err = ed25519.GenerateKey(rand.Reader); err != nil { + if _, privateKey, err = ed25519.GenerateKey(ctx.providers.Random); err != nil { return nil, fmt.Errorf("generating Ed25519 private key resulted in an error: %w", err) } } @@ -336,7 +335,7 @@ func cryptoGetSubjectFromCmd(cmd *cobra.Command) (subject *pkix.Name, err error) }, nil } -func cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certificate, err error) { +func (ctx *CmdCtx) cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certificate, err error) { var ( ca bool subject *pkix.Name @@ -378,7 +377,7 @@ func cryptoGetCertificateFromCmd(cmd *cobra.Command) (certificate *x509.Certific serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - if serialNumber, err = rand.Int(rand.Reader, serialNumberLimit); err != nil { + if serialNumber, err = ctx.providers.Random.IntErr(serialNumberLimit); err != nil { return nil, fmt.Errorf("failed to generate serial number: %w", err) } diff --git a/internal/commands/storage_run.go b/internal/commands/storage_run.go index d6ce0632d..fb39bd7f9 100644 --- a/internal/commands/storage_run.go +++ b/internal/commands/storage_run.go @@ -18,6 +18,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/validator" "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/totp" "github.com/authelia/authelia/v4/internal/utils" @@ -983,7 +984,8 @@ func (ctx *CmdCtx) StorageUserTOTPExportPNGRunE(cmd *cobra.Command, _ []string) } if dir == "" { - dir = utils.RandomString(8, utils.CharSetAlphaNumeric) + rand := &random.Cryptographical{} + dir = rand.StringCustom(8, random.CharSetAlphaNumeric) } if _, err = os.Stat(dir); !os.IsNotExist(err) { diff --git a/internal/commands/util.go b/internal/commands/util.go index 3142088e7..4950c271b 100644 --- a/internal/commands/util.go +++ b/internal/commands/util.go @@ -14,7 +14,7 @@ import ( "golang.org/x/term" "github.com/authelia/authelia/v4/internal/configuration" - "github.com/authelia/authelia/v4/internal/utils" + "github.com/authelia/authelia/v4/internal/random" ) func recoverErr(i any) error { @@ -77,29 +77,29 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar switch c { case "ascii": - charset = utils.CharSetASCII + charset = random.CharSetASCII case "alphanumeric": - charset = utils.CharSetAlphaNumeric + charset = random.CharSetAlphaNumeric case "alphanumeric-lower": - charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric + charset = random.CharSetAlphabeticLower + random.CharSetNumeric case "alphanumeric-upper": - charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric + charset = random.CharSetAlphabeticUpper + random.CharSetNumeric case "alphabetic": - charset = utils.CharSetAlphabetic + charset = random.CharSetAlphabetic case "alphabetic-lower": - charset = utils.CharSetAlphabeticLower + charset = random.CharSetAlphabeticLower case "alphabetic-upper": - charset = utils.CharSetAlphabeticUpper + charset = random.CharSetAlphabeticUpper case "numeric-hex": - charset = utils.CharSetNumericHex + charset = random.CharSetNumericHex case "numeric": - charset = utils.CharSetNumeric + charset = random.CharSetNumeric case "rfc3986": - charset = utils.CharSetRFC3986Unreserved + charset = random.CharSetRFC3986Unreserved case "rfc3986-lower": - charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved + charset = random.CharSetAlphabeticLower + random.CharSetNumeric + random.CharSetSymbolicRFC3986Unreserved case "rfc3986-upper": - charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved + charset = random.CharSetAlphabeticUpper + random.CharSetNumeric + random.CharSetSymbolicRFC3986Unreserved default: return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c) } @@ -109,7 +109,9 @@ func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameChar } } - return utils.RandomString(n, charset), nil + rand := &random.Cryptographical{} + + return rand.StringCustom(n, charset), nil } func termReadConfirmation(flags *pflag.FlagSet, name, prompt, confirmation string) (confirmed bool, err error) { diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go index 3149c10a7..57ac0f613 100644 --- a/internal/middlewares/authelia_context_test.go +++ b/internal/middlewares/authelia_context_test.go @@ -13,6 +13,7 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/session" ) @@ -128,6 +129,7 @@ func TestShouldCallNextWithAutheliaCtx(t *testing.T) { providers := middlewares.Providers{ UserProvider: userProvider, SessionProvider: sessionProvider, + Random: random.NewMathematical(), } nextCalled := false diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go index bb81141fa..2831eef3b 100644 --- a/internal/middlewares/identity_verification_test.go +++ b/internal/middlewares/identity_verification_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/valyala/fasthttp" "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/mocks" @@ -73,8 +74,8 @@ func TestShouldFailSendingAnEmail(t *testing.T) { defer mock.Close() mock.Ctx.Configuration.JWTSecret = testJWTSecret - mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") - mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") + mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedProto, "http") + mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedHost, "host") mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). @@ -95,8 +96,8 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) mock.Ctx.Configuration.JWTSecret = testJWTSecret - mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") - mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") + mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedProto, "http") + mock.Ctx.Request.Header.Add(fasthttp.HeaderXForwardedHost, "host") mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). diff --git a/internal/middlewares/timing_attack_delay.go b/internal/middlewares/timing_attack_delay.go index 163f8752c..b8b497d8c 100644 --- a/internal/middlewares/timing_attack_delay.go +++ b/internal/middlewares/timing_attack_delay.go @@ -1,7 +1,6 @@ package middlewares import ( - "crypto/rand" "math" "math/big" "sync" @@ -62,7 +61,7 @@ func movingAverageIteration(value time.Duration, history int, successful bool, c } func calculateActualDelay(ctx *AutheliaCtx, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) { - randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs)) + randomDelayMs, err := ctx.Providers.Random.IntErr(big.NewInt(maxRandomMs)) if err != nil { return float64(maxRandomMs) } diff --git a/internal/middlewares/timing_attack_delay_test.go b/internal/middlewares/timing_attack_delay_test.go index d7428a1b0..976be04a8 100644 --- a/internal/middlewares/timing_attack_delay_test.go +++ b/internal/middlewares/timing_attack_delay_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/random" ) func TestTimingAttackDelayAverages(t *testing.T) { @@ -45,7 +46,12 @@ func TestTimingAttackDelayCalculations(t *testing.T) { avgExecDurationMs := 1000.0 expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds()) - ctx := &AutheliaCtx{Logger: logging.Logger().WithFields(logrus.Fields{})} + ctx := &AutheliaCtx{ + Logger: logging.Logger().WithFields(logrus.Fields{}), + Providers: Providers{ + Random: &random.Cryptographical{}, + }, + } for i := 0; i < 100; i++ { delay := calculateActualDelay(ctx, execDuration, avgExecDurationMs, 250, 85, false) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index f26e57fc7..14893a606 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -11,6 +11,7 @@ import ( "github.com/authelia/authelia/v4/internal/notification" "github.com/authelia/authelia/v4/internal/ntp" "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/storage" @@ -44,6 +45,7 @@ type Providers struct { Templates *templates.Provider TOTP totp.Provider PasswordPolicy PasswordPolicyProvider + Random random.Provider } // RequestHandler represents an Authelia request handler. diff --git a/internal/mocks/authelia_ctx.go b/internal/mocks/authelia_ctx.go index 26a4a876f..b7fef421c 100644 --- a/internal/mocks/authelia_ctx.go +++ b/internal/mocks/authelia_ctx.go @@ -16,6 +16,7 @@ import ( "github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/templates" @@ -34,6 +35,7 @@ type MockAutheliaCtx struct { StorageMock *MockStorage NotifierMock *MockNotifier TOTPMock *MockTOTP + RandomMock *MockRandom UserSession *session.UserSession @@ -98,6 +100,10 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx { mockAuthelia.TOTPMock = NewMockTOTP(mockAuthelia.Ctrl) providers.TOTP = mockAuthelia.TOTPMock + mockAuthelia.RandomMock = NewMockRandom(mockAuthelia.Ctrl) + + providers.Random = random.NewMathematical() + var err error if providers.Templates, err = templates.New(templates.Config{}); err != nil { diff --git a/internal/mocks/gen.go b/internal/mocks/gen.go index 4ec2697f3..3c54c9612 100644 --- a/internal/mocks/gen.go +++ b/internal/mocks/gen.go @@ -8,3 +8,4 @@ package mocks //go:generate mockgen -package mocks -destination totp.go -mock_names Provider=MockTOTP github.com/authelia/authelia/v4/internal/totp Provider //go:generate mockgen -package mocks -destination storage.go -mock_names Provider=MockStorage github.com/authelia/authelia/v4/internal/storage Provider //go:generate mockgen -package mocks -destination duo_api.go -mock_names API=MockAPI github.com/authelia/authelia/v4/internal/duo API +//go:generate mockgen -package mocks -destination random.go -mock_names Provider=MockRandom github.com/authelia/authelia/v4/internal/random Provider diff --git a/internal/mocks/random.go b/internal/mocks/random.go new file mode 100644 index 000000000..9bab56ddd --- /dev/null +++ b/internal/mocks/random.go @@ -0,0 +1,195 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/authelia/authelia/v4/internal/random (interfaces: Provider) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + big "math/big" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockRandom is a mock of Provider interface. +type MockRandom struct { + ctrl *gomock.Controller + recorder *MockRandomMockRecorder +} + +// MockRandomMockRecorder is the mock recorder for MockRandom. +type MockRandomMockRecorder struct { + mock *MockRandom +} + +// NewMockRandom creates a new mock instance. +func NewMockRandom(ctrl *gomock.Controller) *MockRandom { + mock := &MockRandom{ctrl: ctrl} + mock.recorder = &MockRandomMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRandom) EXPECT() *MockRandomMockRecorder { + return m.recorder +} + +// Bytes mocks base method. +func (m *MockRandom) Bytes() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Bytes") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Bytes indicates an expected call of Bytes. +func (mr *MockRandomMockRecorder) Bytes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockRandom)(nil).Bytes)) +} + +// BytesCustom mocks base method. +func (m *MockRandom) BytesCustom(arg0 int, arg1 []byte) []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BytesCustom", arg0, arg1) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// BytesCustom indicates an expected call of BytesCustom. +func (mr *MockRandomMockRecorder) BytesCustom(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesCustom", reflect.TypeOf((*MockRandom)(nil).BytesCustom), arg0, arg1) +} + +// BytesCustomErr mocks base method. +func (m *MockRandom) BytesCustomErr(arg0 int, arg1 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BytesCustomErr", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BytesCustomErr indicates an expected call of BytesCustomErr. +func (mr *MockRandomMockRecorder) BytesCustomErr(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesCustomErr", reflect.TypeOf((*MockRandom)(nil).BytesCustomErr), arg0, arg1) +} + +// BytesErr mocks base method. +func (m *MockRandom) BytesErr() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BytesErr") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BytesErr indicates an expected call of BytesErr. +func (mr *MockRandomMockRecorder) BytesErr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BytesErr", reflect.TypeOf((*MockRandom)(nil).BytesErr)) +} + +// Int mocks base method. +func (m *MockRandom) Int(arg0 *big.Int) *big.Int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Int", arg0) + ret0, _ := ret[0].(*big.Int) + return ret0 +} + +// Int indicates an expected call of Int. +func (mr *MockRandomMockRecorder) Int(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int", reflect.TypeOf((*MockRandom)(nil).Int), arg0) +} + +// IntErr mocks base method. +func (m *MockRandom) IntErr(arg0 *big.Int) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IntErr", arg0) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IntErr indicates an expected call of IntErr. +func (mr *MockRandomMockRecorder) IntErr(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntErr", reflect.TypeOf((*MockRandom)(nil).IntErr), arg0) +} + +// Integer mocks base method. +func (m *MockRandom) Integer(arg0 int) int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Integer", arg0) + ret0, _ := ret[0].(int) + return ret0 +} + +// Integer indicates an expected call of Integer. +func (mr *MockRandomMockRecorder) Integer(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Integer", reflect.TypeOf((*MockRandom)(nil).Integer), arg0) +} + +// IntegerErr mocks base method. +func (m *MockRandom) IntegerErr(arg0 int) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IntegerErr", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IntegerErr indicates an expected call of IntegerErr. +func (mr *MockRandomMockRecorder) IntegerErr(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntegerErr", reflect.TypeOf((*MockRandom)(nil).IntegerErr), arg0) +} + +// Read mocks base method. +func (m *MockRandom) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockRandomMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockRandom)(nil).Read), arg0) +} + +// StringCustom mocks base method. +func (m *MockRandom) StringCustom(arg0 int, arg1 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StringCustom", arg0, arg1) + ret0, _ := ret[0].(string) + return ret0 +} + +// StringCustom indicates an expected call of StringCustom. +func (mr *MockRandomMockRecorder) StringCustom(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringCustom", reflect.TypeOf((*MockRandom)(nil).StringCustom), arg0, arg1) +} + +// StringCustomErr mocks base method. +func (m *MockRandom) StringCustomErr(arg0 int, arg1 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StringCustomErr", arg0, arg1) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StringCustomErr indicates an expected call of StringCustomErr. +func (mr *MockRandomMockRecorder) StringCustomErr(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StringCustomErr", reflect.TypeOf((*MockRandom)(nil).StringCustomErr), arg0, arg1) +} diff --git a/internal/notification/smtp_notifier.go b/internal/notification/smtp_notifier.go index 7afda8426..c212b2c9c 100644 --- a/internal/notification/smtp_notifier.go +++ b/internal/notification/smtp_notifier.go @@ -14,6 +14,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/templates" "github.com/authelia/authelia/v4/internal/utils" ) @@ -64,6 +65,7 @@ func NewSMTPNotifier(config *schema.SMTPNotifierConfiguration, certPool *x509.Ce return &SMTPNotifier{ config: config, domain: domain, + random: &random.Cryptographical{}, tls: utils.NewTLSConfig(config.TLS, certPool), log: logging.Logger(), opts: opts, @@ -74,6 +76,7 @@ func NewSMTPNotifier(config *schema.SMTPNotifierConfiguration, certPool *x509.Ce type SMTPNotifier struct { config *schema.SMTPNotifierConfiguration domain string + random random.Provider tls *tls.Config log *logrus.Logger opts []gomail.Option @@ -104,10 +107,10 @@ func (n *SMTPNotifier) StartupCheck() (err error) { func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject string, et *templates.EmailTemplate, data any) (err error) { msg := gomail.NewMsg( gomail.WithMIMEVersion(gomail.Mime10), - gomail.WithBoundary(utils.RandomString(30, utils.CharSetAlphaNumeric)), + gomail.WithBoundary(n.random.StringCustom(30, random.CharSetAlphaNumeric)), ) - setMessageID(msg, n.domain) + n.setMessageID(msg, n.domain) if err = msg.From(n.config.Sender.String()); err != nil { return fmt.Errorf("notifier: smtp: failed to set from address: %w", err) @@ -161,10 +164,10 @@ func (n *SMTPNotifier) Send(ctx context.Context, recipient mail.Address, subject return nil } -func setMessageID(msg *gomail.Msg, domain string) { - rn, _ := utils.RandomInt(100000000) - rm, _ := utils.RandomInt(10000) - rs := utils.RandomString(17, utils.CharSetAlphaNumeric) +func (n *SMTPNotifier) setMessageID(msg *gomail.Msg, domain string) { + rn := n.random.Integer(100000000) + rm := n.random.Integer(10000) + rs := n.random.StringCustom(17, random.CharSetAlphaNumeric) pid := os.Getpid() + rm msg.SetMessageIDWithValue(fmt.Sprintf("%d.%d%d.%s@%s", pid, rn, rm, rs, domain)) diff --git a/internal/random/const.go b/internal/random/const.go new file mode 100644 index 000000000..4a171f17f --- /dev/null +++ b/internal/random/const.go @@ -0,0 +1,43 @@ +package random + +const ( + // DefaultN is the default value of n. + DefaultN = 72 +) + +const ( + // CharSetAlphabeticLower are literally just valid alphabetic lowercase printable ASCII chars. + CharSetAlphabeticLower = "abcdefghijklmnopqrstuvwxyz" + + // CharSetAlphabeticUpper are literally just valid alphabetic uppercase printable ASCII chars. + CharSetAlphabeticUpper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + // CharSetAlphabetic are literally just valid alphabetic printable ASCII chars. + CharSetAlphabetic = CharSetAlphabeticLower + CharSetAlphabeticUpper + + // CharSetNumeric are literally just valid numeric chars. + CharSetNumeric = "0123456789" + + // CharSetNumericHex are literally just valid hexadecimal printable ASCII chars. + CharSetNumericHex = CharSetNumeric + "ABCDEF" + + // CharSetSymbolic are literally just valid symbolic printable ASCII chars. + CharSetSymbolic = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + + // CharSetSymbolicRFC3986Unreserved are RFC3986 unreserved symbol characters. + // See https://www.rfc-editor.org/rfc/rfc3986#section-2.3. + CharSetSymbolicRFC3986Unreserved = "-._~" + + // CharSetAlphaNumeric are literally just valid alphanumeric printable ASCII chars. + CharSetAlphaNumeric = CharSetAlphabetic + CharSetNumeric + + // CharSetASCII are literally just valid printable ASCII chars. + CharSetASCII = CharSetAlphabetic + CharSetNumeric + CharSetSymbolic + + // CharSetRFC3986Unreserved are RFC3986 unreserved characters. + // See https://www.rfc-editor.org/rfc/rfc3986#section-2.3. + CharSetRFC3986Unreserved = CharSetAlphabetic + CharSetNumeric + CharSetSymbolicRFC3986Unreserved + + // CharSetUnambiguousUpper are a set of unambiguous uppercase characters. + CharSetUnambiguousUpper = "ABCDEFGHJKLMNOPQRTUVWYXZ2346789" +) diff --git a/internal/random/crypto.go b/internal/random/crypto.go new file mode 100644 index 000000000..bff1e0dce --- /dev/null +++ b/internal/random/crypto.go @@ -0,0 +1,136 @@ +package random + +import ( + "crypto/rand" + "fmt" + "io" + "math/big" +) + +// Cryptographical is the production random.Provider which uses crypto/rand. +type Cryptographical struct{} + +// Read implements the io.Reader interface. +func (r *Cryptographical) Read(p []byte) (n int, err error) { + return io.ReadFull(rand.Reader, p) +} + +// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values +// (including unreadable byte values). If an error is returned from the random read this function returns it. +func (r *Cryptographical) BytesErr() (data []byte, err error) { + data = make([]byte, DefaultN) + + _, err = rand.Read(data) + + return data, err +} + +// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values +// (including unreadable byte values). If an error is returned from the random read this function ignores it. +func (r *Cryptographical) Bytes() (data []byte) { + data, _ = r.BytesErr() + + return data +} + +// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided +// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function +// returns it. +func (r *Cryptographical) BytesCustomErr(n int, charset []byte) (data []byte, err error) { + if n < 1 { + n = DefaultN + } + + data = make([]byte, n) + + if _, err = rand.Read(data); err != nil { + return nil, err + } + + t := len(charset) + + for i := 0; i < n; i++ { + data[i] = charset[data[i]%byte(t)] + } + + return data, nil +} + +// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string. +func (r *Cryptographical) StringCustomErr(n int, characters string) (data string, err error) { + var d []byte + + if d, err = r.BytesCustomErr(n, []byte(characters)); err != nil { + return "", err + } + + return string(d), nil +} + +// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided values. +// If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function +// ignores it. +func (r *Cryptographical) BytesCustom(n int, charset []byte) (data []byte) { + data, _ = r.BytesCustomErr(n, charset) + + return data +} + +// StringCustom is an overload of BytesCustom which takes a characters string and returns a string. +func (r *Cryptographical) StringCustom(n int, characters string) (data string) { + return string(r.BytesCustom(n, []byte(characters))) +} + +// IntErr returns a random *big.Int error combination with a maximum of max. +func (r *Cryptographical) IntErr(max *big.Int) (value *big.Int, err error) { + if max == nil { + return nil, fmt.Errorf("max is required") + } + + if max.Sign() <= 0 { + return nil, fmt.Errorf("max must be 1 or more") + } + + return rand.Int(rand.Reader, max) +} + +// Int returns a random *big.Int with a maximum of max. +func (r *Cryptographical) Int(max *big.Int) (value *big.Int) { + var err error + + if value, err = r.IntErr(max); err != nil { + return big.NewInt(-1) + } + + return value +} + +// IntegerErr returns a random int error combination with a maximum of n. +func (r *Cryptographical) IntegerErr(n int) (value int, err error) { + if n <= 0 { + return 0, fmt.Errorf("n must be more than 0") + } + + max := big.NewInt(int64(n)) + + var result *big.Int + + if result, err = r.IntErr(max); err != nil { + return 0, err + } + + value = int(result.Int64()) + + if value < 0 { + return 0, fmt.Errorf("generated number is too big for int") + } + + return value, nil +} + +// Integer returns a random int with a maximum of n. +func (r *Cryptographical) Integer(n int) (value int) { + value, _ = r.IntegerErr(n) + + return value +} diff --git a/internal/random/math.go b/internal/random/math.go new file mode 100644 index 000000000..d95a2d7f7 --- /dev/null +++ b/internal/random/math.go @@ -0,0 +1,126 @@ +package random + +import ( + "fmt" + "math/big" + "math/rand" + "time" +) + +// NewMathematical runs rand.Seed with the current time and returns a random.Provider, specifically *random.Mathematical. +func NewMathematical() *Mathematical { + rand.Seed(time.Now().UnixNano()) + + return &Mathematical{} +} + +// Mathematical is the random.Provider which uses math/rand and is COMPLETELY UNSAFE FOR PRODUCTION IN MOST SITUATIONS. +// Use random.Cryptographical instead. +type Mathematical struct{} + +// Read implements the io.Reader interface. +func (r *Mathematical) Read(p []byte) (n int, err error) { + return rand.Read(p) //nolint:gosec +} + +// BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values +// (including unreadable byte values). If an error is returned from the random read this function returns it. +func (r *Mathematical) BytesErr() (data []byte, err error) { + data = make([]byte, DefaultN) + + if _, err = rand.Read(data); err != nil { //nolint:gosec + return nil, err + } + + return data, nil +} + +// Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values +// (including unreadable byte values). If an error is returned from the random read this function ignores it. +func (r *Mathematical) Bytes() (data []byte) { + data, _ = r.BytesErr() + + return data +} + +// BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided +// values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function +// returns it. +func (r *Mathematical) BytesCustomErr(n int, charset []byte) (data []byte, err error) { + if n < 1 { + n = DefaultN + } + + data = make([]byte, n) + + if _, err = rand.Read(data); err != nil { //nolint:gosec + return nil, err + } + + t := len(charset) + + for i := 0; i < n; i++ { + data[i] = charset[data[i]%byte(t)] + } + + return data, nil +} + +// StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string. +func (r *Mathematical) StringCustomErr(n int, characters string) (data string, err error) { + var d []byte + + if d, err = r.BytesCustomErr(n, []byte(characters)); err != nil { + return "", err + } + + return string(d), nil +} + +// BytesCustom returns random data as bytes with n length and can contain only byte values from the provided values. +// If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function +// ignores it. +func (r *Mathematical) BytesCustom(n int, charset []byte) (data []byte) { + data, _ = r.BytesCustomErr(n, charset) + + return data +} + +// StringCustom is an overload of BytesCustom which takes a characters string and returns a string. +func (r *Mathematical) StringCustom(n int, characters string) (data string) { + return string(r.BytesCustom(n, []byte(characters))) +} + +// IntErr returns a random *big.Int error combination with a maximum of max. +func (r *Mathematical) IntErr(max *big.Int) (value *big.Int, err error) { + if max == nil { + return nil, fmt.Errorf("max is required") + } + + if max.Sign() <= 0 { + return nil, fmt.Errorf("max must be 1 or more") + } + + return big.NewInt(int64(rand.Intn(max.Sign()))), nil //nolint:gosec +} + +// Int returns a random *big.Int with a maximum of max. +func (r *Mathematical) Int(max *big.Int) (value *big.Int) { + var err error + + if value, err = r.IntErr(max); err != nil { + return big.NewInt(-1) + } + + return value +} + +// IntegerErr returns a random int error combination with a maximum of n. +func (r *Mathematical) IntegerErr(n int) (output int, err error) { + return r.Integer(n), nil +} + +// Integer returns a random int with a maximum of n. +func (r *Mathematical) Integer(n int) int { + return rand.Intn(n) //nolint:gosec +} diff --git a/internal/random/provider.go b/internal/random/provider.go new file mode 100644 index 000000000..7eeda72ab --- /dev/null +++ b/internal/random/provider.go @@ -0,0 +1,46 @@ +package random + +import ( + "io" + "math/big" +) + +// Provider of random functions and functionality. +type Provider interface { + io.Reader + + // BytesErr returns random data as bytes with the standard random.DefaultN length and can contain any byte values + // (including unreadable byte values). If an error is returned from the random read this function returns it. + BytesErr() (data []byte, err error) + + // Bytes returns random data as bytes with the standard random.DefaultN length and can contain any byte values + // (including unreadable byte values). If an error is returned from the random read this function ignores it. + Bytes() (data []byte) + + // BytesCustomErr returns random data as bytes with n length and can contain only byte values from the provided + // values. If n is less than 1 then DefaultN is used instead. If an error is returned from the random read this function + // returns it. + BytesCustomErr(n int, charset []byte) (data []byte, err error) + + // StringCustomErr is an overload of BytesCustomWithErr which takes a characters string and returns a string. + StringCustomErr(n int, characters string) (data string, err error) + + // BytesCustom returns random data as bytes with n length and can contain only byte values from the provided + // values. If n is less than 1 then DefaultN is used instead. + BytesCustom(n int, charset []byte) (data []byte) + + // StringCustom is an overload of GenerateCustom which takes a characters string and returns a string. + StringCustom(n int, characters string) (data string) + + // IntErr returns a random *big.Int error combination with a maximum of max. + IntErr(max *big.Int) (value *big.Int, err error) + + // Int returns a random *big.Int with a maximum of max. + Int(max *big.Int) (value *big.Int) + + // IntegerErr returns a random int error combination with a maximum of n. + IntegerErr(n int) (value int, err error) + + // Integer returns a random integer with a maximum of n. + Integer(n int) (value int) +} diff --git a/internal/server/template.go b/internal/server/template.go index cecd9d2f1..2aa24560d 100644 --- a/internal/server/template.go +++ b/internal/server/template.go @@ -15,8 +15,8 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/random" "github.com/authelia/authelia/v4/internal/templates" - "github.com/authelia/authelia/v4/internal/utils" ) // ServeTemplatedFile serves a templated version of a specified file, @@ -46,7 +46,7 @@ func ServeTemplatedFile(t templates.Template, opts *TemplatedFileOptions) middle ctx.SetContentTypeTextPlain() } - nonce := utils.RandomString(32, utils.CharSetAlphaNumeric) + nonce := ctx.Providers.Random.StringCustom(32, random.CharSetAlphaNumeric) switch { case ctx.Configuration.Server.Headers.CSPTemplate != "": @@ -78,7 +78,7 @@ func ServeTemplatedOpenAPI(t templates.Template, opts *TemplatedFileOptions) mid if spec { ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, tmplCSPSwagger) } else { - nonce = utils.RandomString(32, utils.CharSetAlphaNumeric) + nonce = ctx.Providers.Random.StringCustom(32, random.CharSetAlphaNumeric) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPSwaggerNonce, nonce, nonce)) } diff --git a/internal/utils/const.go b/internal/utils/const.go index 2e65da318..cb02b5459 100644 --- a/internal/utils/const.go +++ b/internal/utils/const.go @@ -97,40 +97,6 @@ const ( timeUnixEpochAsMicrosoftNTEpoch uint64 = 116444736000000000 ) -const ( - // CharSetAlphabeticLower are literally just valid alphabetic lowercase printable ASCII chars. - CharSetAlphabeticLower = "abcdefghijklmnopqrstuvwxyz" - - // CharSetAlphabeticUpper are literally just valid alphabetic uppercase printable ASCII chars. - CharSetAlphabeticUpper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - - // CharSetAlphabetic are literally just valid alphabetic printable ASCII chars. - CharSetAlphabetic = CharSetAlphabeticLower + CharSetAlphabeticUpper - - // CharSetNumeric are literally just valid numeric chars. - CharSetNumeric = "0123456789" - - // CharSetNumericHex are literally just valid hexadecimal printable ASCII chars. - CharSetNumericHex = CharSetNumeric + "ABCDEF" - - // CharSetSymbolic are literally just valid symbolic printable ASCII chars. - CharSetSymbolic = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" - - // CharSetSymbolicRFC3986Unreserved are RFC3986 unreserved symbol characters. - // See https://www.rfc-editor.org/rfc/rfc3986#section-2.3. - CharSetSymbolicRFC3986Unreserved = "-._~" - - // CharSetAlphaNumeric are literally just valid alphanumeric printable ASCII chars. - CharSetAlphaNumeric = CharSetAlphabetic + CharSetNumeric - - // CharSetASCII are literally just valid printable ASCII chars. - CharSetASCII = CharSetAlphabetic + CharSetNumeric + CharSetSymbolic - - // CharSetRFC3986Unreserved are RFC3986 unreserved characters. - // See https://www.rfc-editor.org/rfc/rfc3986#section-2.3. - CharSetRFC3986Unreserved = CharSetAlphabetic + CharSetNumeric + CharSetSymbolicRFC3986Unreserved -) - var htmlEscaper = strings.NewReplacer( "&", "&", "<", "<", diff --git a/internal/utils/const_test.go b/internal/utils/const_test.go index 879485536..fb6116a95 100644 --- a/internal/utils/const_test.go +++ b/internal/utils/const_test.go @@ -1,5 +1,11 @@ package utils +import ( + "github.com/authelia/authelia/v4/internal/random" +) + const ( testStringInput = "abcdefghijkl" ) + +var r = &random.Cryptographical{} diff --git a/internal/utils/crypto.go b/internal/utils/crypto.go index c8b7451db..6d74fc8ea 100644 --- a/internal/utils/crypto.go +++ b/internal/utils/crypto.go @@ -573,50 +573,3 @@ loop: return extKeyUsage } - -// RandomString returns a random string with a given length with values from the provided characters. When crypto is set -// to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true -// excluding when the task is time sensitive and would not benefit from extra randomness. -func RandomString(n int, characters string) (randomString string) { - return string(RandomBytes(n, characters)) -} - -// RandomBytes returns a random []byte with a given length with values from the provided characters. When crypto is set -// to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true -// excluding when the task is time sensitive and would not benefit from extra randomness. -func RandomBytes(n int, characters string) (bytes []byte) { - bytes = make([]byte, n) - - _, _ = rand.Read(bytes) - - for i, b := range bytes { - bytes[i] = characters[b%byte(len(characters))] - } - - return bytes -} - -func RandomInt(n int) (int, error) { - if n <= 0 { - return 0, fmt.Errorf("n must be more than 0") - } - - max := big.NewInt(int64(n)) - - if !max.IsUint64() { - return 0, fmt.Errorf("generated max is negative") - } - - value, err := rand.Int(rand.Reader, max) - if err != nil { - return 0, err - } - - output := int(value.Int64()) - - if output < 0 { - return 0, fmt.Errorf("generated number is too big for int") - } - - return output, nil -} diff --git a/internal/utils/hashing_test.go b/internal/utils/hashing_test.go index c6777189e..f814d1c9f 100644 --- a/internal/utils/hashing_test.go +++ b/internal/utils/hashing_test.go @@ -7,6 +7,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/authelia/authelia/v4/internal/random" ) func TestShouldHashString(t *testing.T) { @@ -22,7 +24,7 @@ func TestShouldHashString(t *testing.T) { assert.Equal(t, "ae448ac86c4e8e4dec645729708ef41873ae79c6dff84eff73360989487f08e5", anotherSum) assert.NotEqual(t, sum, anotherSum) - randomInput := RandomString(40, CharSetAlphaNumeric) + randomInput := r.StringCustom(40, random.CharSetAlphaNumeric) randomSum := HashSHA256FromString(randomInput) assert.NotEqual(t, randomSum, sum) @@ -38,7 +40,7 @@ func TestShouldHashPath(t *testing.T) { err = os.WriteFile(filepath.Join(dir, "anotherfile"), []byte("another\n"), 0600) assert.NoError(t, err) - err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(RandomString(40, CharSetAlphaNumeric)+"\n"), 0600) + err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(r.StringCustom(40, random.CharSetAlphaNumeric)+"\n"), 0600) assert.NoError(t, err) sum, err := HashSHA256FromPath(filepath.Join(dir, "myfile")) diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index 49d3a3f69..fe7c6742e 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -53,13 +53,6 @@ func TestStringJoinDelimitedEscaped(t *testing.T) { } } -func TestShouldNotGenerateSameRandomString(t *testing.T) { - randomStringOne := RandomString(10, CharSetAlphaNumeric) - randomStringTwo := RandomString(10, CharSetAlphaNumeric) - - assert.NotEqual(t, randomStringOne, randomStringTwo) -} - func TestShouldDetectAlphaNumericString(t *testing.T) { assert.True(t, IsStringAlphaNumeric("abc")) assert.True(t, IsStringAlphaNumeric("abc123"))