fix(session): use crypto/rand for session id generator (#2594)

This adjusts the session ID generator making it use it's own random function rather than using one from the utils lib. This allows us to utilize crypto/rand or math/rand interchangeably. Additionally refactor the utils.RandomString func.
pull/2599/head
James Elliott 2021-11-11 20:13:32 +11:00 committed by GitHub
parent 7d5a59098d
commit 7efcac6017
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 69 additions and 28 deletions

View File

@ -60,7 +60,7 @@ const (
) )
// HashingPossibleSaltCharacters represents valid hashing runes. // HashingPossibleSaltCharacters represents valid hashing runes.
var HashingPossibleSaltCharacters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/") var HashingPossibleSaltCharacters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/"
// ErrUserNotFound indicates the user wasn't found in the authentication backend. // ErrUserNotFound indicates the user wasn't found in the authentication backend.
var ErrUserNotFound = errors.New("user not found") var ErrUserNotFound = errors.New("user not found")

View File

@ -126,7 +126,7 @@ func HashPassword(password, salt string, algorithm CryptAlgo, iterations, memory
} }
if salt == "" { if salt == "" {
salt = crypt.Base64Encoding.EncodeToString([]byte(utils.RandomString(saltLength, HashingPossibleSaltCharacters))) salt = crypt.Base64Encoding.EncodeToString(utils.RandomBytes(saltLength, HashingPossibleSaltCharacters, true))
} }
settings = getCryptSettings(salt, algorithm, iterations, memory, parallelism, keyLength) settings = getCryptSettings(salt, algorithm, iterations, memory, parallelism, keyLength)

View File

@ -58,8 +58,7 @@ func TestArgon2idHashSaltValidValues(t *testing.T) {
var hash string var hash string
data := string(HashingPossibleSaltCharacters) datas := utils.SliceString(HashingPossibleSaltCharacters, 16)
datas := utils.SliceString(data, 16)
for _, salt := range datas { for _, salt := range datas {
hash, err = HashPassword("password", salt, HashingAlgorithmArgon2id, 1, 8, 1, 32, 16) hash, err = HashPassword("password", salt, HashingAlgorithmArgon2id, 1, 8, 1, 32, 16)
@ -74,8 +73,7 @@ func TestSHA512HashSaltValidValues(t *testing.T) {
var hash string var hash string
data := string(HashingPossibleSaltCharacters) datas := utils.SliceString(HashingPossibleSaltCharacters, 16)
datas := utils.SliceString(data, 16)
for _, salt := range datas { for _, salt := range datas {
hash, err = HashPassword("password", salt, HashingAlgorithmSHA512, 1000, 0, 0, 0, 16) hash, err = HashPassword("password", salt, HashingAlgorithmSHA512, 1000, 0, 0, 0, 16)

View File

@ -133,7 +133,7 @@ func (n *SMTPNotifier) compose(recipient, subject, body, htmlBody string) error
return err return err
} }
boundary := utils.RandomString(30, utils.AlphaNumericCharacters) boundary := utils.RandomString(30, utils.AlphaNumericCharacters, true)
now := time.Now() now := time.Now()

View File

@ -13,8 +13,6 @@ import (
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
var alphaNumericRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
// ServeTemplatedFile serves a templated version of a specified file, // ServeTemplatedFile serves a templated version of a specified file,
// this is utilised to pass information between the backend and frontend // this is utilised to pass information between the backend and frontend
// and generate a nonce to support a restrictive CSP while using material-ui. // and generate a nonce to support a restrictive CSP while using material-ui.
@ -55,7 +53,7 @@ func ServeTemplatedFile(publicDir, file, rememberMe, resetPassword, session, the
} }
baseURL := scheme + "://" + string(ctx.Request.Host()) + base + "/" baseURL := scheme + "://" + string(ctx.Request.Host()) + base + "/"
nonce := utils.RandomString(32, alphaNumericRunes) nonce := utils.RandomString(32, utils.AlphaNumericCharacters, true)
switch extension := filepath.Ext(file); extension { switch extension := filepath.Ext(file); extension {
case ".html": case ".html":

View File

@ -1,8 +1,13 @@
package session package session
const userSessionStorerKey = "UserSession" const (
testDomain = "example.com"
testExpiration = "40"
testName = "my_session"
testUsername = "john"
)
const testDomain = "example.com" const (
const testExpiration = "40" userSessionStorerKey = "UserSession"
const testName = "my_session" randomSessionChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_!#$%^*"
const testUsername = "john" )

View File

@ -1,6 +1,7 @@
package session package session
import ( import (
"crypto/rand"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
@ -20,7 +21,15 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
config := session.NewDefaultConfig() config := session.NewDefaultConfig()
config.SessionIDGeneratorFunc = func() []byte { config.SessionIDGeneratorFunc = func() []byte {
return []byte(utils.RandomString(30, utils.AlphaNumericCharacters)) bytes := make([]byte, 32)
_, _ = rand.Read(bytes)
for i, b := range bytes {
bytes[i] = randomSessionChars[b%byte(len(randomSessionChars))]
}
return bytes
} }
// Override the cookie name. // Override the cookie name.
@ -47,7 +56,6 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
// Ignore the error as it will be handled by validator. // Ignore the error as it will be handled by validator.
config.Expiration, _ = utils.ParseDurationString(configuration.Expiration) config.Expiration, _ = utils.ParseDurationString(configuration.Expiration)
// TODO(c.michaud): Make this configurable by giving the list of IPs that are trustable.
config.IsSecureFunc = func(*fasthttp.RequestCtx) bool { config.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
return true return true
} }

View File

@ -56,8 +56,10 @@ var (
reDuration = regexp.MustCompile(`^(?P<Duration>[1-9]\d*?)(?P<Unit>[smhdwMy])?$`) reDuration = regexp.MustCompile(`^(?P<Duration>[1-9]\d*?)(?P<Unit>[smhdwMy])?$`)
) )
// AlphaNumericCharacters are literally just valid alphanumeric chars. var (
var AlphaNumericCharacters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") // AlphaNumericCharacters are literally just valid alphanumeric chars.
AlphaNumericCharacters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
var htmlEscaper = strings.NewReplacer( var htmlEscaper = strings.NewReplacer(
"&", "&amp;", "&", "&amp;",

View File

@ -23,7 +23,7 @@ func TestShouldHashString(t *testing.T) {
assert.Equal(t, "ae448ac86c4e8e4dec645729708ef41873ae79c6dff84eff73360989487f08e5", anotherSum) assert.Equal(t, "ae448ac86c4e8e4dec645729708ef41873ae79c6dff84eff73360989487f08e5", anotherSum)
assert.NotEqual(t, sum, anotherSum) assert.NotEqual(t, sum, anotherSum)
randomInput := RandomString(40, AlphaNumericCharacters) randomInput := RandomString(40, AlphaNumericCharacters, false)
randomSum := HashSHA256FromString(randomInput) randomSum := HashSHA256FromString(randomInput)
assert.NotEqual(t, randomSum, sum) assert.NotEqual(t, randomSum, sum)
@ -40,7 +40,7 @@ func TestShouldHashPath(t *testing.T) {
err = os.WriteFile(filepath.Join(dir, "anotherfile"), []byte("another\n"), 0600) err = os.WriteFile(filepath.Join(dir, "anotherfile"), []byte("another\n"), 0600)
assert.NoError(t, err) assert.NoError(t, err)
err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(RandomString(40, AlphaNumericCharacters)+"\n"), 0600) err = os.WriteFile(filepath.Join(dir, "randomfile"), []byte(RandomString(40, AlphaNumericCharacters, true)+"\n"), 0600)
assert.NoError(t, err) assert.NoError(t, err)
sum, err := HashSHA256FromPath(filepath.Join(dir, "myfile")) sum, err := HashSHA256FromPath(filepath.Join(dir, "myfile"))

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
crand "crypto/rand"
"fmt" "fmt"
"math/rand" "math/rand"
"net/url" "net/url"
@ -139,19 +140,37 @@ func StringSlicesDelta(before, after []string) (added, removed []string) {
return added, removed return added, removed
} }
// RandomString generate a random string of n characters. // RandomString returns a random string with a given length with values from the provided characters. When crypto is set
func RandomString(n int, characters []rune) (randomString string) { // to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true
rand.Seed(time.Now().UnixNano()) // excluding when the task is time sensitive and would not benefit from extra randomness.
func RandomString(n int, characters string, crypto bool) (randomString string) {
return string(RandomBytes(n, characters, crypto))
}
b := make([]rune, n) // RandomBytes returns a random []byte with a given length with values from the provided characters. When crypto is set
for i := range b { // to false we use math/rand and when it's set to true we use crypto/rand. The crypto option should always be set to true
b[i] = characters[rand.Intn(len(characters))] //nolint:gosec // Likely isn't necessary to use the more expensive crypto/rand for this utility func. // excluding when the task is time sensitive and would not benefit from extra randomness.
func RandomBytes(n int, characters string, crypto bool) (bytes []byte) {
bytes = make([]byte, n)
if crypto {
_, _ = crand.Read(bytes)
} else {
_, _ = rand.Read(bytes) //nolint:gosec // As this is an option when using this function it's not necessary to be concerned about this.
} }
return string(b) for i, b := range bytes {
bytes[i] = characters[b%byte(len(characters))]
}
return bytes
} }
// StringHTMLEscape escapes chars for a HTML body. // StringHTMLEscape escapes chars for a HTML body.
func StringHTMLEscape(input string) (output string) { func StringHTMLEscape(input string) (output string) {
return htmlEscaper.Replace(input) return htmlEscaper.Replace(input)
} }
func init() {
rand.Seed(time.Now().UnixNano())
}

View File

@ -7,6 +7,17 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestShouldNotGenerateSameRandomString(t *testing.T) {
randomStringOne := RandomString(10, AlphaNumericCharacters, false)
randomStringTwo := RandomString(10, AlphaNumericCharacters, false)
randomCryptoStringOne := RandomString(10, AlphaNumericCharacters, true)
randomCryptoStringTwo := RandomString(10, AlphaNumericCharacters, true)
assert.NotEqual(t, randomStringOne, randomStringTwo)
assert.NotEqual(t, randomCryptoStringOne, randomCryptoStringTwo)
}
func TestShouldDetectAlphaNumericString(t *testing.T) { func TestShouldDetectAlphaNumericString(t *testing.T) {
assert.True(t, IsStringAlphaNumeric("abc")) assert.True(t, IsStringAlphaNumeric("abc"))
assert.True(t, IsStringAlphaNumeric("abc123")) assert.True(t, IsStringAlphaNumeric("abc123"))