test: use test machinery to set env vars in tests (#4640)

This commit replaces `os.Setenv` with `t.Setenv` in tests. The environment variable is automatically restored to its original value when the test and all its subtests complete. Reference: https://pkg.go.dev/testing#T.Setenv

Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
pull/4643/head
Eng Zer Jun 2022-12-26 04:16:05 +08:00 committed by GitHub
parent 3b699b8604
commit 54afe925b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 81 deletions

View File

@ -61,7 +61,7 @@ func TestLoadXEnvCLIStringSliceValue(t *testing.T) {
} }
if tc.envValue != "" { if tc.envValue != "" {
require.NoError(t, os.Setenv(tc.envKey, tc.envValue)) t.Setenv(tc.envKey, tc.envValue)
} }
actual, actualResult, actualErr := loadXEnvCLIStringSliceValue(cmd, tc.envKey, tc.flag.Name) actual, actualResult, actualErr := loadXEnvCLIStringSliceValue(cmd, tc.envKey, tc.flag.Name)
@ -74,10 +74,6 @@ func TestLoadXEnvCLIStringSliceValue(t *testing.T) {
} else { } else {
assert.EqualError(t, actualErr, tc.expectedErr) assert.EqualError(t, actualErr, tc.expectedErr)
} }
if tc.envValue != "" {
require.NoError(t, os.Unsetenv(tc.envKey))
}
}) })
} }
} }

View File

@ -17,8 +17,6 @@ import (
) )
func TestShouldErrorSecretNotExist(t *testing.T) { func TestShouldErrorSecretNotExist(t *testing.T) {
testReset()
dir := t.TempDir() dir := t.TempDir()
testSetEnv(t, "JWT_SECRET_FILE", filepath.Join(dir, "jwt")) testSetEnv(t, "JWT_SECRET_FILE", filepath.Join(dir, "jwt"))
@ -73,8 +71,6 @@ func TestLoadShouldReturnErrWithoutSources(t *testing.T) {
} }
func TestShouldHaveNotifier(t *testing.T) { func TestShouldHaveNotifier(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET", "abc") testSetEnv(t, "SESSION_SECRET", "abc")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc") testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc")
testSetEnv(t, "JWT_SECRET", "abc") testSetEnv(t, "JWT_SECRET", "abc")
@ -90,8 +86,6 @@ func TestShouldHaveNotifier(t *testing.T) {
} }
func TestShouldValidateConfigurationWithEnv(t *testing.T) { func TestShouldValidateConfigurationWithEnv(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET", "abc") testSetEnv(t, "SESSION_SECRET", "abc")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc") testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc")
testSetEnv(t, "JWT_SECRET", "abc") testSetEnv(t, "JWT_SECRET", "abc")
@ -106,15 +100,13 @@ func TestShouldValidateConfigurationWithEnv(t *testing.T) {
} }
func TestShouldValidateConfigurationWithFilters(t *testing.T) { func TestShouldValidateConfigurationWithFilters(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET", "abc") testSetEnv(t, "SESSION_SECRET", "abc")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc") testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc")
testSetEnv(t, "JWT_SECRET", "abc") testSetEnv(t, "JWT_SECRET", "abc")
testSetEnv(t, "AUTHENTICATION_BACKEND_LDAP_PASSWORD", "abc") testSetEnv(t, "AUTHENTICATION_BACKEND_LDAP_PASSWORD", "abc")
_ = os.Setenv("SERVICES_SERVER", "10.10.10.10") t.Setenv("SERVICES_SERVER", "10.10.10.10")
_ = os.Setenv("ROOT_DOMAIN", "example.org") t.Setenv("ROOT_DOMAIN", "example.org")
val := schema.NewStructValidator() val := schema.NewStructValidator()
_, config, err := Load(val, NewDefaultSourcesFiltered([]string{"./test_resources/config.filtered.yml"}, NewFileFiltersDefault(), DefaultEnvPrefix, DefaultEnvDelimiter)...) _, config, err := Load(val, NewDefaultSourcesFiltered([]string{"./test_resources/config.filtered.yml"}, NewFileFiltersDefault(), DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -129,8 +121,6 @@ func TestShouldValidateConfigurationWithFilters(t *testing.T) {
} }
func TestShouldNotIgnoreInvalidEnvs(t *testing.T) { func TestShouldNotIgnoreInvalidEnvs(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET", "an env session secret") testSetEnv(t, "SESSION_SECRET", "an env session secret")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "an env storage mysql password") testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "an env storage mysql password")
testSetEnv(t, "STORAGE_MYSQL", "a bad env") testSetEnv(t, "STORAGE_MYSQL", "a bad env")
@ -152,8 +142,6 @@ func TestShouldNotIgnoreInvalidEnvs(t *testing.T) {
} }
func TestShouldValidateAndRaiseErrorsOnNormalConfigurationAndSecret(t *testing.T) { func TestShouldValidateAndRaiseErrorsOnNormalConfigurationAndSecret(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET", "an env session secret") testSetEnv(t, "SESSION_SECRET", "an env session secret")
testSetEnv(t, "SESSION_SECRET_FILE", "./test_resources/example_secret") testSetEnv(t, "SESSION_SECRET_FILE", "./test_resources/example_secret")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "an env storage mysql password") testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "an env storage mysql password")
@ -182,8 +170,6 @@ func TestShouldRaiseIOErrOnUnreadableFile(t *testing.T) {
t.Skip("skipping test due to being on windows") t.Skip("skipping test due to being on windows")
} }
testReset()
dir := t.TempDir() dir := t.TempDir()
assert.NoError(t, os.WriteFile(filepath.Join(dir, "myconf.yml"), []byte("server:\n port: 9091\n"), 0000)) assert.NoError(t, os.WriteFile(filepath.Join(dir, "myconf.yml"), []byte("server:\n port: 9091\n"), 0000))
@ -200,8 +186,6 @@ func TestShouldRaiseIOErrOnUnreadableFile(t *testing.T) {
} }
func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) { func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET_FILE", "./test_resources/example_secret") testSetEnv(t, "SESSION_SECRET_FILE", "./test_resources/example_secret")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD_FILE", "./test_resources/example_secret") testSetEnv(t, "STORAGE_MYSQL_PASSWORD_FILE", "./test_resources/example_secret")
testSetEnv(t, "JWT_SECRET_FILE", "./test_resources/example_secret") testSetEnv(t, "JWT_SECRET_FILE", "./test_resources/example_secret")
@ -223,8 +207,6 @@ func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) {
} }
func TestShouldLoadURLList(t *testing.T) { func TestShouldLoadURLList(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -241,8 +223,6 @@ func TestShouldLoadURLList(t *testing.T) {
} }
func TestShouldConfigureConsent(t *testing.T) { func TestShouldConfigureConsent(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -258,8 +238,6 @@ func TestShouldConfigureConsent(t *testing.T) {
} }
func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) { func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
testReset()
testSetEnv(t, "SESSION_SECRET", "abc") testSetEnv(t, "SESSION_SECRET", "abc")
testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc") testSetEnv(t, "STORAGE_MYSQL_PASSWORD", "abc")
testSetEnv(t, "JWT_SECRET", "abc") testSetEnv(t, "JWT_SECRET", "abc")
@ -282,8 +260,6 @@ func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
} }
func TestShouldRaiseErrOnInvalidNotifierSMTPSender(t *testing.T) { func TestShouldRaiseErrOnInvalidNotifierSMTPSender(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, _, err := Load(val, NewDefaultSources([]string{"./test_resources/config_smtp_sender_invalid.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, _, err := Load(val, NewDefaultSources([]string{"./test_resources/config_smtp_sender_invalid.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -298,8 +274,6 @@ func TestShouldRaiseErrOnInvalidNotifierSMTPSender(t *testing.T) {
} }
func TestShouldHandleErrInvalidatorWhenSMTPSenderBlank(t *testing.T) { func TestShouldHandleErrInvalidatorWhenSMTPSenderBlank(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_smtp_sender_blank.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_smtp_sender_blank.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -322,8 +296,6 @@ func TestShouldHandleErrInvalidatorWhenSMTPSenderBlank(t *testing.T) {
} }
func TestShouldDecodeSMTPSenderWithoutName(t *testing.T) { func TestShouldDecodeSMTPSenderWithoutName(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -339,8 +311,6 @@ func TestShouldDecodeSMTPSenderWithoutName(t *testing.T) {
} }
func TestShouldDecodeSMTPSenderWithName(t *testing.T) { func TestShouldDecodeSMTPSenderWithName(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_alt.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_alt.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -357,8 +327,6 @@ func TestShouldDecodeSMTPSenderWithName(t *testing.T) {
} }
func TestShouldParseRegex(t *testing.T) { func TestShouldParseRegex(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_domain_regex.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_domain_regex.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -389,8 +357,6 @@ func TestShouldParseRegex(t *testing.T) {
} }
func TestShouldErrOnParseInvalidRegex(t *testing.T) { func TestShouldErrOnParseInvalidRegex(t *testing.T) {
testReset()
val := schema.NewStructValidator() val := schema.NewStructValidator()
keys, _, err := Load(val, NewDefaultSources([]string{"./test_resources/config_domain_bad_regex.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...) keys, _, err := Load(val, NewDefaultSources([]string{"./test_resources/config_domain_bad_regex.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
@ -409,8 +375,6 @@ func TestShouldNotReadConfigurationOnFSAccessDenied(t *testing.T) {
t.Skip("skipping test due to being on windows") t.Skip("skipping test due to being on windows")
} }
testReset()
dir := t.TempDir() dir := t.TempDir()
cfg := filepath.Join(dir, "config.yml") cfg := filepath.Join(dir, "config.yml")
@ -426,8 +390,6 @@ func TestShouldNotReadConfigurationOnFSAccessDenied(t *testing.T) {
} }
func TestShouldLoadDirectoryConfiguration(t *testing.T) { func TestShouldLoadDirectoryConfiguration(t *testing.T) {
testReset()
dir := t.TempDir() dir := t.TempDir()
val := schema.NewStructValidator() val := schema.NewStructValidator()
@ -439,31 +401,7 @@ func TestShouldLoadDirectoryConfiguration(t *testing.T) {
} }
func testSetEnv(t *testing.T, key, value string) { func testSetEnv(t *testing.T, key, value string) {
assert.NoError(t, os.Setenv(DefaultEnvPrefix+key, value)) t.Setenv(DefaultEnvPrefix+key, value)
}
func testReset() {
testUnsetEnvName("STORAGE_MYSQL")
testUnsetEnvName("JWT_SECRET")
testUnsetEnvName("DUO_API_SECRET_KEY")
testUnsetEnvName("SESSION_SECRET")
testUnsetEnvName("AUTHENTICATION_BACKEND_LDAP_PASSWORD")
testUnsetEnvName("AUTHENTICATION_BACKEND_LDAP_URL")
testUnsetEnvName("NOTIFIER_SMTP_PASSWORD")
testUnsetEnvName("SESSION_REDIS_PASSWORD")
testUnsetEnvName("SESSION_REDIS_HIGH_AVAILABILITY_SENTINEL_PASSWORD")
testUnsetEnvName("STORAGE_MYSQL_PASSWORD")
testUnsetEnvName("STORAGE_POSTGRES_PASSWORD")
testUnsetEnvName("SERVER_TLS_KEY")
testUnsetEnvName("SERVER_PORT")
testUnsetEnvName("IDENTITY_PROVIDERS_OIDC_ISSUER_PRIVATE_KEY")
testUnsetEnvName("IDENTITY_PROVIDERS_OIDC_HMAC_SECRET")
testUnsetEnvName("STORAGE_ENCRYPTION_KEY")
}
func testUnsetEnvName(name string) {
_ = os.Unsetenv(DefaultEnvPrefix + name)
_ = os.Unsetenv(DefaultEnvPrefix + name + constSecretSuffix)
} }
func testCreateFile(path, value string, perm os.FileMode) (err error) { func testCreateFile(path, value string, perm os.FileMode) (err error) {

View File

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"hash" "hash"
"os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -43,16 +42,12 @@ func TestFuncGetEnv(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
for key, value := range tc.have { for key, value := range tc.have {
assert.NoError(t, os.Setenv(key, value)) t.Setenv(key, value)
} }
for key, expected := range tc.expected { for key, expected := range tc.expected {
assert.Equal(t, expected, FuncGetEnv(key)) assert.Equal(t, expected, FuncGetEnv(key))
} }
for key := range tc.have {
assert.NoError(t, os.Unsetenv(key))
}
}) })
} }
} }
@ -86,14 +81,10 @@ func TestFuncExpandEnv(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
for key, value := range tc.env { for key, value := range tc.env {
assert.NoError(t, os.Setenv(key, value)) t.Setenv(key, value)
} }
assert.Equal(t, tc.expected, FuncExpandEnv(tc.have)) assert.Equal(t, tc.expected, FuncExpandEnv(tc.have))
for key := range tc.env {
assert.NoError(t, os.Unsetenv(key))
}
}) })
} }
} }