diff --git a/docs/content/en/configuration/security/access-control.md b/docs/content/en/configuration/security/access-control.md index 4be5546a2..011cec776 100644 --- a/docs/content/en/configuration/security/access-control.md +++ b/docs/content/en/configuration/security/access-control.md @@ -542,14 +542,14 @@ if they have a path of exactly `/api` or if they start with `/api/`. This means a match for that request. ```yaml -- domains: +- domain: - 'example.com' - '*.example.com' policy: bypass resources: - '^/api$' - '^/api/' -- domains: +- domain: - 'app.example.com' policy: two_factor ``` diff --git a/internal/configuration/validator/access_control.go b/internal/configuration/validator/access_control.go index 9509c4cab..994d7559c 100644 --- a/internal/configuration/validator/access_control.go +++ b/internal/configuration/validator/access_control.go @@ -90,9 +90,7 @@ func ValidateRules(config *schema.Configuration, validator *schema.StructValidat for i, rule := range config.AccessControl.Rules { rulePosition := i + 1 - if len(rule.Domains)+len(rule.DomainsRegex) == 0 { - validator.Push(fmt.Errorf(errFmtAccessControlRuleNoDomains, ruleDescriptor(rulePosition, rule))) - } + validateDomains(rulePosition, rule, validator) if !IsPolicyValid(rule.Policy) { validator.Push(fmt.Errorf(errFmtAccessControlRuleInvalidPolicy, ruleDescriptor(rulePosition, rule), rule.Policy)) @@ -125,6 +123,18 @@ func validateBypass(rulePosition int, rule schema.ACLRule, validator *schema.Str } } +func validateDomains(rulePosition int, rule schema.ACLRule, validator *schema.StructValidator) { + if len(rule.Domains)+len(rule.DomainsRegex) == 0 { + validator.Push(fmt.Errorf(errFmtAccessControlRuleNoDomains, ruleDescriptor(rulePosition, rule))) + } + + for i, domain := range rule.Domains { + if len(domain) > 1 && domain[0] == '*' && domain[1] != '.' { + validator.PushWarning(fmt.Errorf("access control: rule #%d: domain #%d: domain '%s' is ineffective and should probably be '%s' instead", rulePosition, i+1, domain, fmt.Sprintf("*.%s", domain[1:]))) + } + } +} + func validateNetworks(rulePosition int, rule schema.ACLRule, config schema.AccessControlConfiguration, validator *schema.StructValidator) { for _, network := range rule.Networks { if !IsNetworkValid(network) { diff --git a/internal/configuration/validator/access_control_test.go b/internal/configuration/validator/access_control_test.go index ae7dabb18..0671455a1 100644 --- a/internal/configuration/validator/access_control_test.go +++ b/internal/configuration/validator/access_control_test.go @@ -88,6 +88,22 @@ func (suite *AccessControl) TestShouldRaiseErrorInvalidNetworkGroupNetwork() { suite.Assert().EqualError(suite.validator.Errors()[0], "access control: networks: network group 'internal' is invalid: the network 'abc.def.ghi.jkl' is not a valid IP or CIDR notation") } +func (suite *AccessControl) TestShouldRaiseWarningOnBadDomain() { + suite.config.AccessControl.Rules = []schema.ACLRule{ + { + Domains: []string{"*example.com"}, + Policy: "one_factor", + }, + } + + ValidateRules(suite.config, suite.validator) + + suite.Assert().Len(suite.validator.Warnings(), 1) + suite.Require().Len(suite.validator.Errors(), 0) + + suite.Assert().EqualError(suite.validator.Warnings()[0], "access control: rule #1: domain #1: domain '*example.com' is ineffective and should probably be '*.example.com' instead") +} + func (suite *AccessControl) TestShouldRaiseErrorWithNoRulesDefined() { suite.config.AccessControl.Rules = []schema.ACLRule{} diff --git a/internal/configuration/validator/identity_providers.go b/internal/configuration/validator/identity_providers.go index 3bea1c63b..23a797153 100644 --- a/internal/configuration/validator/identity_providers.go +++ b/internal/configuration/validator/identity_providers.go @@ -12,11 +12,11 @@ import ( ) // ValidateIdentityProviders validates and updates the IdentityProviders configuration. -func ValidateIdentityProviders(config *schema.IdentityProvidersConfiguration, validator *schema.StructValidator) { - validateOIDC(config.OIDC, validator) +func ValidateIdentityProviders(config *schema.IdentityProvidersConfiguration, val *schema.StructValidator) { + validateOIDC(config.OIDC, val) } -func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { +func validateOIDC(config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { if config == nil { return } @@ -25,37 +25,37 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S switch { case config.IssuerPrivateKey == nil: - validator.Push(fmt.Errorf(errFmtOIDCNoPrivateKey)) + val.Push(fmt.Errorf(errFmtOIDCNoPrivateKey)) default: if config.IssuerCertificateChain.HasCertificates() { if !config.IssuerCertificateChain.EqualKey(config.IssuerPrivateKey) { - validator.Push(fmt.Errorf(errFmtOIDCCertificateMismatch)) + val.Push(fmt.Errorf(errFmtOIDCCertificateMismatch)) } if err := config.IssuerCertificateChain.Validate(); err != nil { - validator.Push(fmt.Errorf(errFmtOIDCCertificateChain, err)) + val.Push(fmt.Errorf(errFmtOIDCCertificateChain, err)) } } if config.IssuerPrivateKey.Size()*8 < 2048 { - validator.Push(fmt.Errorf(errFmtOIDCInvalidPrivateKeyBitSize, 2048, config.IssuerPrivateKey.Size()*8)) + val.Push(fmt.Errorf(errFmtOIDCInvalidPrivateKeyBitSize, 2048, config.IssuerPrivateKey.Size()*8)) } } if config.MinimumParameterEntropy != 0 && config.MinimumParameterEntropy < 8 { - validator.PushWarning(fmt.Errorf(errFmtOIDCServerInsecureParameterEntropy, config.MinimumParameterEntropy)) + val.PushWarning(fmt.Errorf(errFmtOIDCServerInsecureParameterEntropy, config.MinimumParameterEntropy)) } if config.EnforcePKCE != "never" && config.EnforcePKCE != "public_clients_only" && config.EnforcePKCE != "always" { - validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE)) + val.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE)) } - validateOIDCOptionsCORS(config, validator) + validateOIDCOptionsCORS(config, val) if len(config.Clients) == 0 { - validator.Push(fmt.Errorf(errFmtOIDCNoClientsConfigured)) + val.Push(fmt.Errorf(errFmtOIDCNoClientsConfigured)) } else { - validateOIDCClients(config, validator) + validateOIDCClients(config, val) } } @@ -91,26 +91,26 @@ func validateOIDCOptionsCORS(config *schema.OpenIDConnectConfiguration, validato validateOIDCOptionsCORSEndpoints(config, validator) } -func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { +func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { for _, origin := range config.CORS.AllowedOrigins { if origin.String() == "*" { if len(config.CORS.AllowedOrigins) != 1 { - validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard)) + val.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard)) } if config.CORS.AllowedOriginsFromClientRedirectURIs { - validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients)) + val.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients)) } continue } if origin.Path != "" { - validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path")) + val.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path")) } if origin.RawQuery != "" { - validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string")) + val.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string")) } } } @@ -132,16 +132,15 @@ func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema. } } -func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { +func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { for _, endpoint := range config.CORS.Endpoints { if !utils.IsStringInSlice(endpoint, validOIDCCORSEndpoints) { - validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '"))) + val.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '"))) } } } -//nolint:gocyclo // TODO: Refactor. -func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { +func validateOIDCClients(config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { invalidID, duplicateIDs := false, false var ids []string @@ -162,176 +161,179 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s if client.Public { if client.Secret != nil { - validator.Push(fmt.Errorf(errFmtOIDCClientPublicInvalidSecret, client.ID)) + val.Push(fmt.Errorf(errFmtOIDCClientPublicInvalidSecret, client.ID)) } } else { if client.Secret == nil { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSecret, client.ID)) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSecret, client.ID)) } } if client.Policy == "" { config.Clients[c].Policy = schema.DefaultOpenIDConnectClientConfiguration.Policy } else if client.Policy != policyOneFactor && client.Policy != policyTwoFactor { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidPolicy, client.ID, client.Policy)) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidPolicy, client.ID, client.Policy)) } - switch { - case utils.IsStringInSlice(client.ConsentMode, []string{"", "auto"}): - if client.ConsentPreConfiguredDuration != nil { - config.Clients[c].ConsentMode = oidc.ClientConsentModePreConfigured.String() - } else { - config.Clients[c].ConsentMode = oidc.ClientConsentModeExplicit.String() - } - case utils.IsStringInSlice(client.ConsentMode, validOIDCClientConsentModes): - break - default: - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidConsentMode, client.ID, strings.Join(append(validOIDCClientConsentModes, "auto"), "', '"), client.ConsentMode)) - } - - if client.ConsentPreConfiguredDuration == nil { - config.Clients[c].ConsentPreConfiguredDuration = schema.DefaultOpenIDConnectClientConfiguration.ConsentPreConfiguredDuration - } - - validateOIDCClientSectorIdentifier(client, validator) - validateOIDCClientScopes(c, config, validator) - validateOIDCClientGrantTypes(c, config, validator) - validateOIDCClientResponseTypes(c, config, validator) - validateOIDCClientResponseModes(c, config, validator) - validateOIDDClientUserinfoAlgorithm(c, config, validator) - validateOIDCClientRedirectURIs(client, validator) + validateOIDCClientConsentMode(c, config, val) + validateOIDCClientSectorIdentifier(client, val) + validateOIDCClientScopes(c, config, val) + validateOIDCClientGrantTypes(c, config, val) + validateOIDCClientResponseTypes(c, config, val) + validateOIDCClientResponseModes(c, config, val) + validateOIDDClientUserinfoAlgorithm(c, config, val) + validateOIDCClientRedirectURIs(client, val) } if invalidID { - validator.Push(fmt.Errorf(errFmtOIDCClientsWithEmptyID)) + val.Push(fmt.Errorf(errFmtOIDCClientsWithEmptyID)) } if duplicateIDs { - validator.Push(fmt.Errorf(errFmtOIDCClientsDuplicateID)) + val.Push(fmt.Errorf(errFmtOIDCClientsDuplicateID)) } } -func validateOIDCClientSectorIdentifier(client schema.OpenIDConnectClientConfiguration, validator *schema.StructValidator) { +func validateOIDCClientSectorIdentifier(client schema.OpenIDConnectClientConfiguration, val *schema.StructValidator) { if client.SectorIdentifier.String() != "" { if utils.IsURLHostComponent(client.SectorIdentifier) || utils.IsURLHostComponentWithPort(client.SectorIdentifier) { return } if client.SectorIdentifier.Scheme != "" { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "scheme", client.SectorIdentifier.Scheme)) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "scheme", client.SectorIdentifier.Scheme)) if client.SectorIdentifier.Path != "" { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "path", client.SectorIdentifier.Path)) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "path", client.SectorIdentifier.Path)) } if client.SectorIdentifier.RawQuery != "" { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "query", client.SectorIdentifier.RawQuery)) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "query", client.SectorIdentifier.RawQuery)) } if client.SectorIdentifier.Fragment != "" { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "fragment", client.SectorIdentifier.Fragment)) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "fragment", client.SectorIdentifier.Fragment)) } if client.SectorIdentifier.User != nil { if client.SectorIdentifier.User.Username() != "" { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "username", client.SectorIdentifier.User.Username())) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "username", client.SectorIdentifier.User.Username())) } if _, set := client.SectorIdentifier.User.Password(); set { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifierWithoutValue, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "password")) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifierWithoutValue, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "password")) } } } else if client.SectorIdentifier.Host == "" { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifierHost, client.ID, client.SectorIdentifier.String())) + val.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifierHost, client.ID, client.SectorIdentifier.String())) } } } -func validateOIDCClientScopes(c int, configuration *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { - if len(configuration.Clients[c].Scopes) == 0 { - configuration.Clients[c].Scopes = schema.DefaultOpenIDConnectClientConfiguration.Scopes +func validateOIDCClientConsentMode(c int, config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { + switch { + case utils.IsStringInSlice(config.Clients[c].ConsentMode, []string{"", "auto"}): + if config.Clients[c].ConsentPreConfiguredDuration != nil { + config.Clients[c].ConsentMode = oidc.ClientConsentModePreConfigured.String() + } else { + config.Clients[c].ConsentMode = oidc.ClientConsentModeExplicit.String() + } + case utils.IsStringInSlice(config.Clients[c].ConsentMode, validOIDCClientConsentModes): + break + default: + val.Push(fmt.Errorf(errFmtOIDCClientInvalidConsentMode, config.Clients[c].ID, strings.Join(append(validOIDCClientConsentModes, "auto"), "', '"), config.Clients[c].ConsentMode)) + } + + if config.Clients[c].ConsentMode == oidc.ClientConsentModePreConfigured.String() && config.Clients[c].ConsentPreConfiguredDuration == nil { + config.Clients[c].ConsentPreConfiguredDuration = schema.DefaultOpenIDConnectClientConfiguration.ConsentPreConfiguredDuration + } +} + +func validateOIDCClientScopes(c int, config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { + if len(config.Clients[c].Scopes) == 0 { + config.Clients[c].Scopes = schema.DefaultOpenIDConnectClientConfiguration.Scopes return } - if !utils.IsStringInSlice(oidc.ScopeOpenID, configuration.Clients[c].Scopes) { - configuration.Clients[c].Scopes = append(configuration.Clients[c].Scopes, oidc.ScopeOpenID) + if !utils.IsStringInSlice(oidc.ScopeOpenID, config.Clients[c].Scopes) { + config.Clients[c].Scopes = append(config.Clients[c].Scopes, oidc.ScopeOpenID) } - for _, scope := range configuration.Clients[c].Scopes { + for _, scope := range config.Clients[c].Scopes { if !utils.IsStringInSlice(scope, validOIDCScopes) { - validator.Push(fmt.Errorf( + val.Push(fmt.Errorf( errFmtOIDCClientInvalidEntry, - configuration.Clients[c].ID, "scopes", strings.Join(validOIDCScopes, "', '"), scope)) + config.Clients[c].ID, "scopes", strings.Join(validOIDCScopes, "', '"), scope)) } } } -func validateOIDCClientGrantTypes(c int, configuration *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { - if len(configuration.Clients[c].GrantTypes) == 0 { - configuration.Clients[c].GrantTypes = schema.DefaultOpenIDConnectClientConfiguration.GrantTypes +func validateOIDCClientGrantTypes(c int, config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { + if len(config.Clients[c].GrantTypes) == 0 { + config.Clients[c].GrantTypes = schema.DefaultOpenIDConnectClientConfiguration.GrantTypes return } - for _, grantType := range configuration.Clients[c].GrantTypes { + for _, grantType := range config.Clients[c].GrantTypes { if !utils.IsStringInSlice(grantType, validOIDCGrantTypes) { - validator.Push(fmt.Errorf( + val.Push(fmt.Errorf( errFmtOIDCClientInvalidEntry, - configuration.Clients[c].ID, "grant_types", strings.Join(validOIDCGrantTypes, "', '"), grantType)) + config.Clients[c].ID, "grant_types", strings.Join(validOIDCGrantTypes, "', '"), grantType)) } } } -func validateOIDCClientResponseTypes(c int, configuration *schema.OpenIDConnectConfiguration, _ *schema.StructValidator) { - if len(configuration.Clients[c].ResponseTypes) == 0 { - configuration.Clients[c].ResponseTypes = schema.DefaultOpenIDConnectClientConfiguration.ResponseTypes +func validateOIDCClientResponseTypes(c int, config *schema.OpenIDConnectConfiguration, _ *schema.StructValidator) { + if len(config.Clients[c].ResponseTypes) == 0 { + config.Clients[c].ResponseTypes = schema.DefaultOpenIDConnectClientConfiguration.ResponseTypes return } } -func validateOIDCClientResponseModes(c int, configuration *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { - if len(configuration.Clients[c].ResponseModes) == 0 { - configuration.Clients[c].ResponseModes = schema.DefaultOpenIDConnectClientConfiguration.ResponseModes +func validateOIDCClientResponseModes(c int, config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { + if len(config.Clients[c].ResponseModes) == 0 { + config.Clients[c].ResponseModes = schema.DefaultOpenIDConnectClientConfiguration.ResponseModes return } - for _, responseMode := range configuration.Clients[c].ResponseModes { + for _, responseMode := range config.Clients[c].ResponseModes { if !utils.IsStringInSlice(responseMode, validOIDCResponseModes) { validator.Push(fmt.Errorf( errFmtOIDCClientInvalidEntry, - configuration.Clients[c].ID, "response_modes", strings.Join(validOIDCResponseModes, "', '"), responseMode)) + config.Clients[c].ID, "response_modes", strings.Join(validOIDCResponseModes, "', '"), responseMode)) } } } -func validateOIDDClientUserinfoAlgorithm(c int, configuration *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { - if configuration.Clients[c].UserinfoSigningAlgorithm == "" { - configuration.Clients[c].UserinfoSigningAlgorithm = schema.DefaultOpenIDConnectClientConfiguration.UserinfoSigningAlgorithm - } else if !utils.IsStringInSlice(configuration.Clients[c].UserinfoSigningAlgorithm, validOIDCUserinfoAlgorithms) { - validator.Push(fmt.Errorf(errFmtOIDCClientInvalidUserinfoAlgorithm, - configuration.Clients[c].ID, strings.Join(validOIDCUserinfoAlgorithms, ", "), configuration.Clients[c].UserinfoSigningAlgorithm)) +func validateOIDDClientUserinfoAlgorithm(c int, config *schema.OpenIDConnectConfiguration, val *schema.StructValidator) { + if config.Clients[c].UserinfoSigningAlgorithm == "" { + config.Clients[c].UserinfoSigningAlgorithm = schema.DefaultOpenIDConnectClientConfiguration.UserinfoSigningAlgorithm + } else if !utils.IsStringInSlice(config.Clients[c].UserinfoSigningAlgorithm, validOIDCUserinfoAlgorithms) { + val.Push(fmt.Errorf(errFmtOIDCClientInvalidUserinfoAlgorithm, + config.Clients[c].ID, strings.Join(validOIDCUserinfoAlgorithms, ", "), config.Clients[c].UserinfoSigningAlgorithm)) } } -func validateOIDCClientRedirectURIs(client schema.OpenIDConnectClientConfiguration, validator *schema.StructValidator) { +func validateOIDCClientRedirectURIs(client schema.OpenIDConnectClientConfiguration, val *schema.StructValidator) { for _, redirectURI := range client.RedirectURIs { if redirectURI == oauth2InstalledApp { if client.Public { continue } - validator.Push(fmt.Errorf(errFmtOIDCClientRedirectURIPublic, client.ID, oauth2InstalledApp)) + val.Push(fmt.Errorf(errFmtOIDCClientRedirectURIPublic, client.ID, oauth2InstalledApp)) continue } parsedURL, err := url.Parse(redirectURI) if err != nil { - validator.Push(fmt.Errorf(errFmtOIDCClientRedirectURICantBeParsed, client.ID, redirectURI, err)) + val.Push(fmt.Errorf(errFmtOIDCClientRedirectURICantBeParsed, client.ID, redirectURI, err)) continue } if !parsedURL.IsAbs() || (!client.Public && parsedURL.Scheme == "") { - validator.Push(fmt.Errorf(errFmtOIDCClientRedirectURIAbsolute, client.ID, redirectURI)) + val.Push(fmt.Errorf(errFmtOIDCClientRedirectURIAbsolute, client.ID, redirectURI)) return } } diff --git a/internal/configuration/validator/identity_providers_test.go b/internal/configuration/validator/identity_providers_test.go index b07f9f29f..fad289432 100644 --- a/internal/configuration/validator/identity_providers_test.go +++ b/internal/configuration/validator/identity_providers_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/url" + "strings" "testing" "time" @@ -313,6 +314,23 @@ func TestShouldRaiseErrorWhenOIDCServerClientBadValues(t *testing.T) { fmt.Sprintf(errFmtOIDCClientInvalidSectorIdentifierHost, "client-invalid-sector", "example.com/path?query=abc#fragment"), }, }, + { + Name: "InvalidConsentMode", + Clients: []schema.OpenIDConnectClientConfiguration{ + { + ID: "client-bad-consent-mode", + Secret: MustDecodeSecret("$plaintext$a-secret"), + Policy: policyTwoFactor, + RedirectURIs: []string{ + "https://google.com", + }, + ConsentMode: "cap", + }, + }, + Errors: []string{ + fmt.Sprintf(errFmtOIDCClientInvalidConsentMode, "client-bad-consent-mode", strings.Join(append(validOIDCClientConsentModes, "auto"), "', '"), "cap"), + }, + }, } for _, tc := range testCases { @@ -633,6 +651,8 @@ func TestValidateIdentityProvidersShouldNotRaiseErrorsOnValidPublicClients(t *te } func TestValidateIdentityProvidersShouldSetDefaultValues(t *testing.T) { + timeDay := time.Hour * 24 + validator := schema.NewStructValidator() config := &schema.IdentityProvidersConfiguration{ OIDC: &schema.OpenIDConnectConfiguration{ @@ -645,6 +665,7 @@ func TestValidateIdentityProvidersShouldSetDefaultValues(t *testing.T) { RedirectURIs: []string{ "https://google.com", }, + ConsentPreConfiguredDuration: &timeDay, }, { ID: "b-client", @@ -670,6 +691,30 @@ func TestValidateIdentityProvidersShouldSetDefaultValues(t *testing.T) { "fragment", }, }, + { + ID: "c-client", + Secret: MustDecodeSecret("$plaintext$a-client-secret"), + RedirectURIs: []string{ + "https://google.com", + }, + ConsentMode: "implicit", + }, + { + ID: "d-client", + Secret: MustDecodeSecret("$plaintext$a-client-secret"), + RedirectURIs: []string{ + "https://google.com", + }, + ConsentMode: "explicit", + }, + { + ID: "e-client", + Secret: MustDecodeSecret("$plaintext$a-client-secret"), + RedirectURIs: []string{ + "https://google.com", + }, + ConsentMode: "pre-configured", + }, }, }, } @@ -702,6 +747,15 @@ func TestValidateIdentityProvidersShouldSetDefaultValues(t *testing.T) { assert.Equal(t, "groups", config.OIDC.Clients[1].Scopes[0]) assert.Equal(t, "openid", config.OIDC.Clients[1].Scopes[1]) + // Assert Clients[0] ends up configured with the correct consent mode. + require.NotNil(t, config.OIDC.Clients[0].ConsentPreConfiguredDuration) + assert.Equal(t, time.Hour*24, *config.OIDC.Clients[0].ConsentPreConfiguredDuration) + assert.Equal(t, "pre-configured", config.OIDC.Clients[0].ConsentMode) + + // Assert Clients[1] ends up configured with the correct consent mode. + assert.Nil(t, config.OIDC.Clients[1].ConsentPreConfiguredDuration) + assert.Equal(t, "explicit", config.OIDC.Clients[1].ConsentMode) + // Assert Clients[0] ends up configured with the default GrantTypes. require.Len(t, config.OIDC.Clients[0].GrantTypes, 2) assert.Equal(t, "refresh_token", config.OIDC.Clients[0].GrantTypes[0]) @@ -736,6 +790,15 @@ func TestValidateIdentityProvidersShouldSetDefaultValues(t *testing.T) { assert.Equal(t, time.Minute, config.OIDC.AuthorizeCodeLifespan) assert.Equal(t, time.Hour, config.OIDC.IDTokenLifespan) assert.Equal(t, time.Minute*90, config.OIDC.RefreshTokenLifespan) + + assert.Equal(t, "implicit", config.OIDC.Clients[2].ConsentMode) + assert.Nil(t, config.OIDC.Clients[2].ConsentPreConfiguredDuration) + + assert.Equal(t, "explicit", config.OIDC.Clients[3].ConsentMode) + assert.Nil(t, config.OIDC.Clients[3].ConsentPreConfiguredDuration) + + assert.Equal(t, "pre-configured", config.OIDC.Clients[4].ConsentMode) + assert.Equal(t, schema.DefaultOpenIDConnectClientConfiguration.ConsentPreConfiguredDuration, config.OIDC.Clients[4].ConsentPreConfiguredDuration) } // All valid schemes are supported as defined in https://datatracker.ietf.org/doc/html/rfc8252#section-7.1 diff --git a/internal/configuration/validator/session.go b/internal/configuration/validator/session.go index 7cf7a84b0..17917661d 100644 --- a/internal/configuration/validator/session.go +++ b/internal/configuration/validator/session.go @@ -41,6 +41,8 @@ func validateSession(config *schema.SessionConfiguration, validator *schema.Stru if config.Domain == "" { validator.Push(fmt.Errorf(errFmtSessionOptionRequired, "domain")) + } else if strings.HasPrefix(config.Domain, ".") { + validator.PushWarning(fmt.Errorf("session: option 'domain' has a prefix of '.' which is not supported or intended behaviour: you can use this at your own risk but we recommend removing it")) } if strings.HasPrefix(config.Domain, "*.") { diff --git a/internal/configuration/validator/session_test.go b/internal/configuration/validator/session_test.go index 29ef3d91f..5a08f988a 100644 --- a/internal/configuration/validator/session_test.go +++ b/internal/configuration/validator/session_test.go @@ -49,6 +49,20 @@ func TestShouldSetDefaultSessionValuesWhenNegative(t *testing.T) { assert.Equal(t, schema.DefaultSessionConfiguration.RememberMeDuration, config.RememberMeDuration) } +func TestShouldWarnSessionValuesWhenPotentiallyInvalid(t *testing.T) { + validator := schema.NewStructValidator() + config := newDefaultSessionConfig() + + config.Domain = ".example.com" + + ValidateSession(&config, validator) + + require.Len(t, validator.Warnings(), 1) + assert.Len(t, validator.Errors(), 0) + + assert.EqualError(t, validator.Warnings()[0], "session: option 'domain' has a prefix of '.' which is not supported or intended behaviour: you can use this at your own risk but we recommend removing it") +} + func TestShouldHandleRedisConfigSuccessfully(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultSessionConfig()