From 232503105279f6e78faa5d4e627dab47b81efade Mon Sep 17 00:00:00 2001 From: James Elliott Date: Sat, 3 Sep 2022 11:51:02 +1000 Subject: [PATCH] refactor: clean up uri checking functions (#3943) --- .../authorization/access_control_subjects.go | 2 +- internal/authorization/types_test.go | 8 ++-- internal/commands/acl.go | 2 +- .../validator/configuration_test.go | 2 +- .../validator/identity_providers.go | 3 +- .../handler_checks_safe_redirection.go | 2 +- internal/handlers/handler_logout.go | 4 +- .../handler_oidc_authorization_consent.go | 8 ++-- internal/handlers/response.go | 6 +-- internal/middlewares/authelia_context.go | 6 +-- internal/middlewares/const.go | 9 +++- internal/middlewares/cors.go | 2 +- internal/model/oidc.go | 2 +- internal/model/totp_configuration.go | 6 +-- internal/oidc/keys.go | 14 +++--- internal/oidc/store.go | 10 ++--- internal/session/user_session.go | 2 +- internal/storage/sql_provider_encryption.go | 16 +++---- internal/storage/sql_provider_schema.go | 2 +- internal/utils/const.go | 1 + internal/utils/strings.go | 2 +- internal/utils/strings_test.go | 6 +-- internal/utils/url.go | 44 ++++++++++++------- internal/utils/url_test.go | 12 ++--- 24 files changed, 96 insertions(+), 75 deletions(-) diff --git a/internal/authorization/access_control_subjects.go b/internal/authorization/access_control_subjects.go index d7e1f0f9f..4159f0d24 100644 --- a/internal/authorization/access_control_subjects.go +++ b/internal/authorization/access_control_subjects.go @@ -19,7 +19,7 @@ func (acs *AccessControlSubjects) AddSubject(subjectRule string) { } // IsMatch returns true if the ACL subjects match the subject properties. -func (acs AccessControlSubjects) IsMatch(subject Subject) (match bool) { +func (acs *AccessControlSubjects) IsMatch(subject Subject) (match bool) { for _, rule := range acs.Subjects { if !rule.IsMatch(subject) { return false diff --git a/internal/authorization/types_test.go b/internal/authorization/types_test.go index 91b6a608a..c598af59a 100644 --- a/internal/authorization/types_test.go +++ b/internal/authorization/types_test.go @@ -9,7 +9,7 @@ import ( ) func TestShouldAppendQueryParamToURL(t *testing.T) { - targetURL, err := url.Parse("https://domain.example.com/api?type=none") + targetURL, err := url.ParseRequestURI("https://domain.example.com/api?type=none") require.NoError(t, err) @@ -22,7 +22,7 @@ func TestShouldAppendQueryParamToURL(t *testing.T) { } func TestShouldCreateNewObjectFromRaw(t *testing.T) { - targetURL, err := url.Parse("https://domain.example.com/api") + targetURL, err := url.ParseRequestURI("https://domain.example.com/api") require.NoError(t, err) @@ -55,7 +55,7 @@ func TestShouldCleanURL(t *testing.T) { for _, tc := range testCases { t.Run(tc.have, func(t *testing.T) { - have, err := url.Parse(tc.have + tc.havePath) + have, err := url.ParseRequestURI(tc.have + tc.havePath) require.NoError(t, err) object := NewObject(have, tc.method) @@ -66,7 +66,7 @@ func TestShouldCleanURL(t *testing.T) { assert.Equal(t, tc.expectedPathClean, object.Path) assert.Equal(t, tc.method, object.Method) - have, err = url.Parse(tc.have) + have, err = url.ParseRequestURI(tc.have) require.NoError(t, err) path, err := url.ParseRequestURI(tc.havePath) diff --git a/internal/commands/acl.go b/internal/commands/acl.go index 05ef8c75e..0b0851b55 100644 --- a/internal/commands/acl.go +++ b/internal/commands/acl.go @@ -208,7 +208,7 @@ func getSubjectAndObjectFromFlags(cmd *cobra.Command) (subject authorization.Sub return subject, object, err } - parsedURL, err := url.Parse(requestURL) + parsedURL, err := url.ParseRequestURI(requestURL) if err != nil { return subject, object, err } diff --git a/internal/configuration/validator/configuration_test.go b/internal/configuration/validator/configuration_test.go index a93fd4aa1..4021c1f33 100644 --- a/internal/configuration/validator/configuration_test.go +++ b/internal/configuration/validator/configuration_test.go @@ -99,7 +99,7 @@ func TestShouldRaiseErrorWithBadDefaultRedirectionURL(t *testing.T) { require.Len(t, validator.Errors(), 1) require.Len(t, validator.Warnings(), 1) - assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: the url 'bad_default_redirection_url' is not absolute because it doesn't start with a scheme like 'ldap://' or 'ldaps://'") + assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: could not parse 'bad_default_redirection_url' as a URL") assert.EqualError(t, validator.Warnings()[0], "access control: no rules have been specified so the 'default_policy' of 'two_factor' is going to be applied to all requests") } diff --git a/internal/configuration/validator/identity_providers.go b/internal/configuration/validator/identity_providers.go index 675745450..4cdfa3800 100644 --- a/internal/configuration/validator/identity_providers.go +++ b/internal/configuration/validator/identity_providers.go @@ -95,7 +95,7 @@ func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfigura func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) { for _, client := range config.Clients { for _, redirectURI := range client.RedirectURIs { - uri, err := url.Parse(redirectURI) + uri, err := url.ParseRequestURI(redirectURI) if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" { continue } @@ -116,6 +116,7 @@ func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, } } } + func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { invalidID, duplicateIDs := false, false diff --git a/internal/handlers/handler_checks_safe_redirection.go b/internal/handlers/handler_checks_safe_redirection.go index 83b7918dd..d53055965 100644 --- a/internal/handlers/handler_checks_safe_redirection.go +++ b/internal/handlers/handler_checks_safe_redirection.go @@ -25,7 +25,7 @@ func CheckSafeRedirectionPOST(ctx *middlewares.AutheliaCtx) { return } - safe, err := utils.IsRedirectionURISafe(reqBody.URI, ctx.Configuration.Session.Domain) + safe, err := utils.IsURIStringSafeRedirection(reqBody.URI, ctx.Configuration.Session.Domain) if err != nil { ctx.Error(fmt.Errorf("unable to determine if uri %s is safe to redirect to: %w", reqBody.URI, err), messageOperationFailed) return diff --git a/internal/handlers/handler_logout.go b/internal/handlers/handler_logout.go index 45ae564ba..0f491dd87 100644 --- a/internal/handlers/handler_logout.go +++ b/internal/handlers/handler_logout.go @@ -31,9 +31,9 @@ func LogoutPOST(ctx *middlewares.AutheliaCtx) { ctx.Error(fmt.Errorf("unable to destroy session during logout: %s", err), messageOperationFailed) } - redirectionURL, err := url.Parse(body.TargetURL) + redirectionURL, err := url.ParseRequestURI(body.TargetURL) if err == nil { - responseBody.SafeTargetURL = utils.URLDomainHasSuffix(*redirectionURL, ctx.Configuration.Session.Domain) + responseBody.SafeTargetURL = utils.IsURISafeRedirection(redirectionURL, ctx.Configuration.Session.Domain) } if body.TargetURL != "" { diff --git a/internal/handlers/handler_oidc_authorization_consent.go b/internal/handlers/handler_oidc_authorization_consent.go index 1a4b40bd6..3fdfb25d1 100644 --- a/internal/handlers/handler_oidc_authorization_consent.go +++ b/internal/handlers/handler_oidc_authorization_consent.go @@ -29,7 +29,7 @@ func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string err error ) - if issuer, err = url.Parse(rootURI); err != nil { + if issuer, err = url.ParseRequestURI(rootURI); err != nil { ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not safely determine the issuer.")) return nil, true @@ -178,7 +178,7 @@ func handleOIDCAuthorizationConsentRedirect(ctx *middlewares.AutheliaCtx, issuer var location *url.URL if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { - location, _ = url.Parse(issuer.String()) + location, _ = url.ParseRequestURI(issuer.String()) location.Path = path.Join(location.Path, "/consent") query := location.Query() @@ -229,9 +229,9 @@ func verifyOIDCUserAuthorizedForConsent(ctx *middlewares.AutheliaCtx, client *oi } func getOIDCAuthorizationRedirectURL(issuer *url.URL, requester fosite.AuthorizeRequester) (redirectURL *url.URL) { - redirectURL, _ = url.Parse(issuer.String()) + redirectURL, _ = url.ParseRequestURI(issuer.String()) - authorizationURL, _ := url.Parse(issuer.String()) + authorizationURL, _ := url.ParseRequestURI(issuer.String()) authorizationURL.Path = path.Join(authorizationURL.Path, oidc.AuthorizationPath) authorizationURL.RawQuery = requester.GetRequestForm().Encode() diff --git a/internal/handlers/response.go b/internal/handlers/response.go index 2cd07ddbb..e1a9d970e 100644 --- a/internal/handlers/response.go +++ b/internal/handlers/response.go @@ -104,7 +104,7 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI, requestMethod st return } - if !utils.URLDomainHasSuffix(*targetURL, ctx.Configuration.Session.Domain) { + if !utils.IsURISafeRedirection(targetURL, ctx.Configuration.Session.Domain) { ctx.Logger.Debugf("Redirection URL %s is not safe", targetURI) if !ctx.Providers.Authorizer.IsSecondFactorEnabled() && ctx.Configuration.DefaultRedirectionURL != "" { @@ -147,7 +147,7 @@ func Handle2FAResponse(ctx *middlewares.AutheliaCtx, targetURI string) { var safe bool - if safe, err = utils.IsRedirectionURISafe(targetURI, ctx.Configuration.Session.Domain); err != nil { + if safe, err = utils.IsURIStringSafeRedirection(targetURI, ctx.Configuration.Session.Domain); err != nil { ctx.Error(fmt.Errorf("unable to check target URL: %s", err), messageMFAValidationFailed) return @@ -176,7 +176,7 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba referer := ctx.Request.Header.Referer() if referer != nil { - refererURL, err := url.Parse(string(referer)) + refererURL, err := url.ParseRequestURI(string(referer)) if err == nil { requestURI = refererURL.Query().Get("rd") requestMethod = refererURL.Query().Get("rm") diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 584c3df85..5189e1435 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -195,7 +195,7 @@ func (ctx *AutheliaCtx) ExternalRootURL() (string, error) { externalRootURL := fmt.Sprintf("%s://%s", protocol, host) if base := ctx.BasePath(); base != "" { - externalBaseURL, err := url.Parse(externalRootURL) + externalBaseURL, err := url.ParseRequestURI(externalRootURL) if err != nil { return "", err } @@ -317,14 +317,14 @@ func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) { } // IsXHR returns true if the request is a XMLHttpRequest. -func (ctx AutheliaCtx) IsXHR() (xhr bool) { +func (ctx *AutheliaCtx) IsXHR() (xhr bool) { requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith) return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR) } // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type. -func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { +func (ctx *AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",") for i, accept := range accepts { diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index 3975df0af..bbbbb8235 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -62,9 +62,14 @@ var ( headerValueCohort = []byte("interest-cohort=()") ) +const ( + strProtoHTTPS = "https" + strProtoHTTP = "http" +) + var ( - protoHTTPS = []byte("https") - protoHTTP = []byte("http") + protoHTTPS = []byte(strProtoHTTPS) + protoHTTP = []byte(strProtoHTTP) // UserValueKeyBaseURL is the User Value key where we store the Base URL. UserValueKeyBaseURL = []byte("base_url") diff --git a/internal/middlewares/cors.go b/internal/middlewares/cors.go index 7936e6f70..d5a7234a4 100644 --- a/internal/middlewares/cors.go +++ b/internal/middlewares/cors.go @@ -265,7 +265,7 @@ func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) { origin := ctx.Request.Header.PeekBytes(headerOrigin) // Skip processing of any `https` scheme URL that has not expressly been configured. - if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) { + if originURL, err = url.ParseRequestURI(string(origin)); err != nil || (originURL.Scheme != strProtoHTTPS && p.origins == nil) { return } diff --git a/internal/model/oidc.go b/internal/model/oidc.go index f41816c94..c25e46eb4 100644 --- a/internal/model/oidc.go +++ b/internal/model/oidc.go @@ -188,7 +188,7 @@ func (s *OAuth2Session) SetSubject(subject string) { } // ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage. -func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) { +func (s *OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) { sessionData := s.Session if session != nil { diff --git a/internal/model/totp_configuration.go b/internal/model/totp_configuration.go index 70387b78c..74777e4a2 100644 --- a/internal/model/totp_configuration.go +++ b/internal/model/totp_configuration.go @@ -23,7 +23,7 @@ type TOTPConfiguration struct { } // URI shows the configuration in the URI representation. -func (c TOTPConfiguration) URI() (uri string) { +func (c *TOTPConfiguration) URI() (uri string) { v := url.Values{} v.Set("secret", string(c.Secret)) v.Set("issuer", c.Issuer) @@ -47,12 +47,12 @@ func (c *TOTPConfiguration) UpdateSignInInfo(now time.Time) { } // Key returns the *otp.Key using TOTPConfiguration.URI with otp.NewKeyFromURL. -func (c TOTPConfiguration) Key() (key *otp.Key, err error) { +func (c *TOTPConfiguration) Key() (key *otp.Key, err error) { return otp.NewKeyFromURL(c.URI()) } // Image returns the image.Image of the TOTPConfiguration using the Image func from the return of TOTPConfiguration.Key. -func (c TOTPConfiguration) Image(width, height int) (img image.Image, err error) { +func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error) { var key *otp.Key if key, err = c.Key(); err != nil { diff --git a/internal/oidc/keys.go b/internal/oidc/keys.go index 946526252..d5351061d 100644 --- a/internal/oidc/keys.go +++ b/internal/oidc/keys.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/ory/fosite/token/jwt" - "gopkg.in/square/go-jose.v2" + jose "gopkg.in/square/go-jose.v2" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/utils" @@ -38,17 +38,17 @@ func NewKeyManager() (manager *KeyManager) { } // Strategy returns the RS256JWTStrategy. -func (m KeyManager) Strategy() (strategy *RS256JWTStrategy) { +func (m *KeyManager) Strategy() (strategy *RS256JWTStrategy) { return m.strategy } // GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types. -func (m KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) { +func (m *KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) { return m.keySet } // GetActiveWebKey obtains the currently active jose.JSONWebKey. -func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) { +func (m *KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) { webKeys := m.keySet.Key(m.activeKeyID) if len(webKeys) == 1 { return &webKeys[0], nil @@ -62,12 +62,12 @@ func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) { } // GetActiveKeyID returns the key id of the currently active key. -func (m KeyManager) GetActiveKeyID() (keyID string) { +func (m *KeyManager) GetActiveKeyID() (keyID string) { return m.activeKeyID } // GetActiveKey returns the rsa.PublicKey of the currently active key. -func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) { +func (m *KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) { if key, ok := m.keys[m.activeKeyID]; ok { return &key.PublicKey, nil } @@ -76,7 +76,7 @@ func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) { } // GetActivePrivateKey returns the rsa.PrivateKey of the currently active key. -func (m KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) { +func (m *KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) { if key, ok := m.keys[m.activeKeyID]; ok { return key, nil } diff --git a/internal/oidc/store.go b/internal/oidc/store.go index 2331df0fa..af3fff7d0 100644 --- a/internal/oidc/store.go +++ b/internal/oidc/store.go @@ -38,7 +38,7 @@ func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider s } // GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username. -func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { +func (s *OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil { return nil, err } else if opaqueID == nil { @@ -55,7 +55,7 @@ func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, } // GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it. -func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) { +func (s *OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) { var opaqueID *model.UserOpaqueIdentifier if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil { @@ -66,7 +66,7 @@ func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username s } // GetClientPolicy retrieves the policy from the client with the matching provided id. -func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) { +func (s *OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) { client, err := s.GetFullClient(id) if err != nil { return authorization.TwoFactor @@ -76,7 +76,7 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve } // GetFullClient returns a fosite.Client asserted as an Client matching the provided id. -func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) { +func (s *OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) { client, ok := s.clients[id] if !ok { return nil, fosite.ErrNotFound @@ -86,7 +86,7 @@ func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) } // IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map. -func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) { +func (s *OpenIDConnectStore) IsValidClientID(id string) (valid bool) { _, err := s.GetFullClient(id) return err == nil diff --git a/internal/session/user_session.go b/internal/session/user_session.go index 24ee85067..d691ec695 100644 --- a/internal/session/user_session.go +++ b/internal/session/user_session.go @@ -61,7 +61,7 @@ func (s *UserSession) SetTwoFactorWebauthn(now time.Time, userPresence, userVeri } // AuthenticatedTime returns the unix timestamp this session authenticated successfully at the given level. -func (s UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) { +func (s *UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) { switch level { case authorization.OneFactor: return time.Unix(s.FirstFactorAuthnTimestamp, 0), nil diff --git a/internal/storage/sql_provider_encryption.go b/internal/storage/sql_provider_encryption.go index 64974b19a..338bb27bf 100644 --- a/internal/storage/sql_provider_encryption.go +++ b/internal/storage/sql_provider_encryption.go @@ -139,7 +139,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool errs = append(errs, err) } - if err = p.schemaEncryptionCheckU2F(ctx); err != nil { + if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil { errs = append(errs, err) } } @@ -210,9 +210,9 @@ func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) return nil } -func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) { +func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) { var ( - device model.U2FDevice + device model.WebauthnDevice row int invalid int total int @@ -226,7 +226,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil { _ = rows.Close() - return fmt.Errorf("error selecting U2F devices: %w", err) + return fmt.Errorf("error selecting Webauthn devices: %w", err) } row = 0 @@ -237,7 +237,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) if err = rows.StructScan(&device); err != nil { _ = rows.Close() - return fmt.Errorf("error scanning U2F device to struct: %w", err) + return fmt.Errorf("error scanning Webauthn device to struct: %w", err) } if _, err = p.decrypt(device.PublicKey); err != nil { @@ -253,17 +253,17 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) } if invalid != 0 { - return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total) + return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total) } return nil } -func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { +func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { return utils.Encrypt(clearText, &p.key) } -func (p SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) { +func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) { return utils.Decrypt(cipherText, &p.key) } diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index 490aba5c6..e01aee1ad 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -246,7 +246,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.Sc return p.schemaMigrateFinalize(ctx, migration) } -func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) { +func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) { return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After()) } diff --git a/internal/utils/const.go b/internal/utils/const.go index b9f0b33fa..d72aa5cd7 100644 --- a/internal/utils/const.go +++ b/internal/utils/const.go @@ -34,6 +34,7 @@ const ( const ( period = "." https = "https" + wss = "wss" ) // X.509 consts. diff --git a/internal/utils/strings.go b/internal/utils/strings.go index d776eb7e3..fd13f66d9 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -16,7 +16,7 @@ import ( // IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error // describing why. func IsStringAbsURL(input string) (err error) { - parsedURL, err := url.Parse(input) + parsedURL, err := url.ParseRequestURI(input) if err != nil { return fmt.Errorf("could not parse '%s' as a URL", input) } diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index 5d8e5f01a..0e2b92651 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -235,13 +235,13 @@ func TestStringSliceURLConversionFuncs(t *testing.T) { func TestIsURLInSlice(t *testing.T) { urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"}) - google, err := url.Parse("https://google.com") + google, err := url.ParseRequestURI("https://google.com") assert.NoError(t, err) - microsoft, err := url.Parse("https://microsoft.com") + microsoft, err := url.ParseRequestURI("https://microsoft.com") assert.NoError(t, err) - example, err := url.Parse("https://example.com") + example, err := url.ParseRequestURI("https://example.com") assert.NoError(t, err) assert.True(t, IsURLInSlice(*google, urls)) diff --git a/internal/utils/url.go b/internal/utils/url.go index 09d7e8e70..dbe472fc8 100644 --- a/internal/utils/url.go +++ b/internal/utils/url.go @@ -30,12 +30,37 @@ func URLPathFullClean(u *url.URL) (output string) { } } -// URLDomainHasSuffix determines whether the uri has a suffix of the domain value. -func URLDomainHasSuffix(uri url.URL, domain string) bool { - if uri.Scheme != https { - return false +// IsURIStringSafeRedirection determines whether the URI is safe to be redirected to. +func IsURIStringSafeRedirection(uri, protectedDomain string) (safe bool, err error) { + var parsedURI *url.URL + + if parsedURI, err = url.ParseRequestURI(uri); err != nil { + return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err) } + return parsedURI != nil && IsURISafeRedirection(parsedURI, protectedDomain), nil +} + +// IsURISafeRedirection returns true if the URI passes the IsURISecure and HasURIDomainSuffix, i.e. if the scheme is +// secure and the given URI has a hostname that is either exactly equal to the given domain or if it has a suffix of the +// domain prefixed with a period. +func IsURISafeRedirection(uri *url.URL, domain string) bool { + return IsURISecure(uri) && HasURIDomainSuffix(uri, domain) +} + +// IsURISecure returns true if the URI has a secure schemes (https or wss). +func IsURISecure(uri *url.URL) bool { + switch uri.Scheme { + case https, wss: + return true + default: + return false + } +} + +// HasURIDomainSuffix returns true if the URI hostname is equal to the domain or if it has a suffix of the domain +// prefixed with a period. +func HasURIDomainSuffix(uri *url.URL, domain string) bool { if uri.Hostname() == domain { return true } @@ -46,14 +71,3 @@ func URLDomainHasSuffix(uri url.URL, domain string) bool { return false } - -// IsRedirectionURISafe determines whether the URI is safe to be redirected to. -func IsRedirectionURISafe(uri, protectedDomain string) (safe bool, err error) { - var parsedURI *url.URL - - if parsedURI, err = url.ParseRequestURI(uri); err != nil { - return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err) - } - - return parsedURI != nil && URLDomainHasSuffix(*parsedURI, protectedDomain), nil -} diff --git a/internal/utils/url_test.go b/internal/utils/url_test.go index 0873b1b26..0cbb8a0dc 100644 --- a/internal/utils/url_test.go +++ b/internal/utils/url_test.go @@ -29,7 +29,7 @@ func TestURLPathFullClean(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - u, err := url.Parse(tc.have) + u, err := url.ParseRequestURI(tc.have) require.NoError(t, err) actual := URLPathFullClean(u) @@ -41,7 +41,7 @@ func TestURLPathFullClean(t *testing.T) { func isURLSafe(requestURI string, domain string) bool { //nolint:unparam u, _ := url.ParseRequestURI(requestURI) - return URLDomainHasSuffix(*u, domain) + return IsURISafeRedirection(u, domain) } func TestIsRedirectionSafe_ShouldReturnTrueOnExactDomain(t *testing.T) { @@ -62,22 +62,22 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) { } func TestIsRedirectionURISafe_CannotParseURI(t *testing.T) { - _, err := IsRedirectionURISafe("http//invalid", "example.com") + _, err := IsURIStringSafeRedirection("http//invalid", "example.com") assert.EqualError(t, err, "failed to parse URI 'http//invalid': parse \"http//invalid\": invalid URI for request") } func TestIsRedirectionURISafe_InvalidRedirectionURI(t *testing.T) { - valid, err := IsRedirectionURISafe("http://myurl.com/myresource", "example.com") + valid, err := IsURIStringSafeRedirection("http://myurl.com/myresource", "example.com") assert.NoError(t, err) assert.False(t, valid) } func TestIsRedirectionURISafe_ValidRedirectionURI(t *testing.T) { - valid, err := IsRedirectionURISafe("http://myurl.example.com/myresource", "example.com") + valid, err := IsURIStringSafeRedirection("http://myurl.example.com/myresource", "example.com") assert.NoError(t, err) assert.False(t, valid) - valid, err = IsRedirectionURISafe("http://example.com/myresource", "example.com") + valid, err = IsURIStringSafeRedirection("http://example.com/myresource", "example.com") assert.NoError(t, err) assert.False(t, valid) }