refactor: clean up uri checking functions (#3943)
parent
02636966a8
commit
2325031052
|
@ -19,7 +19,7 @@ func (acs *AccessControlSubjects) AddSubject(subjectRule string) {
|
|||
}
|
||||
|
||||
// IsMatch returns true if the ACL subjects match the subject properties.
|
||||
func (acs AccessControlSubjects) IsMatch(subject Subject) (match bool) {
|
||||
func (acs *AccessControlSubjects) IsMatch(subject Subject) (match bool) {
|
||||
for _, rule := range acs.Subjects {
|
||||
if !rule.IsMatch(subject) {
|
||||
return false
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func TestShouldAppendQueryParamToURL(t *testing.T) {
|
||||
targetURL, err := url.Parse("https://domain.example.com/api?type=none")
|
||||
targetURL, err := url.ParseRequestURI("https://domain.example.com/api?type=none")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -22,7 +22,7 @@ func TestShouldAppendQueryParamToURL(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldCreateNewObjectFromRaw(t *testing.T) {
|
||||
targetURL, err := url.Parse("https://domain.example.com/api")
|
||||
targetURL, err := url.ParseRequestURI("https://domain.example.com/api")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -55,7 +55,7 @@ func TestShouldCleanURL(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.have, func(t *testing.T) {
|
||||
have, err := url.Parse(tc.have + tc.havePath)
|
||||
have, err := url.ParseRequestURI(tc.have + tc.havePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
object := NewObject(have, tc.method)
|
||||
|
@ -66,7 +66,7 @@ func TestShouldCleanURL(t *testing.T) {
|
|||
assert.Equal(t, tc.expectedPathClean, object.Path)
|
||||
assert.Equal(t, tc.method, object.Method)
|
||||
|
||||
have, err = url.Parse(tc.have)
|
||||
have, err = url.ParseRequestURI(tc.have)
|
||||
require.NoError(t, err)
|
||||
|
||||
path, err := url.ParseRequestURI(tc.havePath)
|
||||
|
|
|
@ -208,7 +208,7 @@ func getSubjectAndObjectFromFlags(cmd *cobra.Command) (subject authorization.Sub
|
|||
return subject, object, err
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(requestURL)
|
||||
parsedURL, err := url.ParseRequestURI(requestURL)
|
||||
if err != nil {
|
||||
return subject, object, err
|
||||
}
|
||||
|
|
|
@ -99,7 +99,7 @@ func TestShouldRaiseErrorWithBadDefaultRedirectionURL(t *testing.T) {
|
|||
require.Len(t, validator.Errors(), 1)
|
||||
require.Len(t, validator.Warnings(), 1)
|
||||
|
||||
assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: the url 'bad_default_redirection_url' is not absolute because it doesn't start with a scheme like 'ldap://' or 'ldaps://'")
|
||||
assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: could not parse 'bad_default_redirection_url' as a URL")
|
||||
assert.EqualError(t, validator.Warnings()[0], "access control: no rules have been specified so the 'default_policy' of 'two_factor' is going to be applied to all requests")
|
||||
}
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfigura
|
|||
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
|
||||
for _, client := range config.Clients {
|
||||
for _, redirectURI := range client.RedirectURIs {
|
||||
uri, err := url.Parse(redirectURI)
|
||||
uri, err := url.ParseRequestURI(redirectURI)
|
||||
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
|
||||
continue
|
||||
}
|
||||
|
@ -116,6 +116,7 @@ func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||
invalidID, duplicateIDs := false, false
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ func CheckSafeRedirectionPOST(ctx *middlewares.AutheliaCtx) {
|
|||
return
|
||||
}
|
||||
|
||||
safe, err := utils.IsRedirectionURISafe(reqBody.URI, ctx.Configuration.Session.Domain)
|
||||
safe, err := utils.IsURIStringSafeRedirection(reqBody.URI, ctx.Configuration.Session.Domain)
|
||||
if err != nil {
|
||||
ctx.Error(fmt.Errorf("unable to determine if uri %s is safe to redirect to: %w", reqBody.URI, err), messageOperationFailed)
|
||||
return
|
||||
|
|
|
@ -31,9 +31,9 @@ func LogoutPOST(ctx *middlewares.AutheliaCtx) {
|
|||
ctx.Error(fmt.Errorf("unable to destroy session during logout: %s", err), messageOperationFailed)
|
||||
}
|
||||
|
||||
redirectionURL, err := url.Parse(body.TargetURL)
|
||||
redirectionURL, err := url.ParseRequestURI(body.TargetURL)
|
||||
if err == nil {
|
||||
responseBody.SafeTargetURL = utils.URLDomainHasSuffix(*redirectionURL, ctx.Configuration.Session.Domain)
|
||||
responseBody.SafeTargetURL = utils.IsURISafeRedirection(redirectionURL, ctx.Configuration.Session.Domain)
|
||||
}
|
||||
|
||||
if body.TargetURL != "" {
|
||||
|
|
|
@ -29,7 +29,7 @@ func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string
|
|||
err error
|
||||
)
|
||||
|
||||
if issuer, err = url.Parse(rootURI); err != nil {
|
||||
if issuer, err = url.ParseRequestURI(rootURI); err != nil {
|
||||
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not safely determine the issuer."))
|
||||
|
||||
return nil, true
|
||||
|
@ -178,7 +178,7 @@ func handleOIDCAuthorizationConsentRedirect(ctx *middlewares.AutheliaCtx, issuer
|
|||
var location *url.URL
|
||||
|
||||
if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
|
||||
location, _ = url.Parse(issuer.String())
|
||||
location, _ = url.ParseRequestURI(issuer.String())
|
||||
location.Path = path.Join(location.Path, "/consent")
|
||||
|
||||
query := location.Query()
|
||||
|
@ -229,9 +229,9 @@ func verifyOIDCUserAuthorizedForConsent(ctx *middlewares.AutheliaCtx, client *oi
|
|||
}
|
||||
|
||||
func getOIDCAuthorizationRedirectURL(issuer *url.URL, requester fosite.AuthorizeRequester) (redirectURL *url.URL) {
|
||||
redirectURL, _ = url.Parse(issuer.String())
|
||||
redirectURL, _ = url.ParseRequestURI(issuer.String())
|
||||
|
||||
authorizationURL, _ := url.Parse(issuer.String())
|
||||
authorizationURL, _ := url.ParseRequestURI(issuer.String())
|
||||
|
||||
authorizationURL.Path = path.Join(authorizationURL.Path, oidc.AuthorizationPath)
|
||||
authorizationURL.RawQuery = requester.GetRequestForm().Encode()
|
||||
|
|
|
@ -104,7 +104,7 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI, requestMethod st
|
|||
return
|
||||
}
|
||||
|
||||
if !utils.URLDomainHasSuffix(*targetURL, ctx.Configuration.Session.Domain) {
|
||||
if !utils.IsURISafeRedirection(targetURL, ctx.Configuration.Session.Domain) {
|
||||
ctx.Logger.Debugf("Redirection URL %s is not safe", targetURI)
|
||||
|
||||
if !ctx.Providers.Authorizer.IsSecondFactorEnabled() && ctx.Configuration.DefaultRedirectionURL != "" {
|
||||
|
@ -147,7 +147,7 @@ func Handle2FAResponse(ctx *middlewares.AutheliaCtx, targetURI string) {
|
|||
|
||||
var safe bool
|
||||
|
||||
if safe, err = utils.IsRedirectionURISafe(targetURI, ctx.Configuration.Session.Domain); err != nil {
|
||||
if safe, err = utils.IsURIStringSafeRedirection(targetURI, ctx.Configuration.Session.Domain); err != nil {
|
||||
ctx.Error(fmt.Errorf("unable to check target URL: %s", err), messageMFAValidationFailed)
|
||||
|
||||
return
|
||||
|
@ -176,7 +176,7 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba
|
|||
|
||||
referer := ctx.Request.Header.Referer()
|
||||
if referer != nil {
|
||||
refererURL, err := url.Parse(string(referer))
|
||||
refererURL, err := url.ParseRequestURI(string(referer))
|
||||
if err == nil {
|
||||
requestURI = refererURL.Query().Get("rd")
|
||||
requestMethod = refererURL.Query().Get("rm")
|
||||
|
|
|
@ -195,7 +195,7 @@ func (ctx *AutheliaCtx) ExternalRootURL() (string, error) {
|
|||
externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
|
||||
|
||||
if base := ctx.BasePath(); base != "" {
|
||||
externalBaseURL, err := url.Parse(externalRootURL)
|
||||
externalBaseURL, err := url.ParseRequestURI(externalRootURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -317,14 +317,14 @@ func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
|
|||
}
|
||||
|
||||
// IsXHR returns true if the request is a XMLHttpRequest.
|
||||
func (ctx AutheliaCtx) IsXHR() (xhr bool) {
|
||||
func (ctx *AutheliaCtx) IsXHR() (xhr bool) {
|
||||
requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith)
|
||||
|
||||
return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR)
|
||||
}
|
||||
|
||||
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
|
||||
func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
||||
func (ctx *AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
||||
accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",")
|
||||
|
||||
for i, accept := range accepts {
|
||||
|
|
|
@ -62,9 +62,14 @@ var (
|
|||
headerValueCohort = []byte("interest-cohort=()")
|
||||
)
|
||||
|
||||
const (
|
||||
strProtoHTTPS = "https"
|
||||
strProtoHTTP = "http"
|
||||
)
|
||||
|
||||
var (
|
||||
protoHTTPS = []byte("https")
|
||||
protoHTTP = []byte("http")
|
||||
protoHTTPS = []byte(strProtoHTTPS)
|
||||
protoHTTP = []byte(strProtoHTTP)
|
||||
|
||||
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
||||
UserValueKeyBaseURL = []byte("base_url")
|
||||
|
|
|
@ -265,7 +265,7 @@ func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
|
|||
origin := ctx.Request.Header.PeekBytes(headerOrigin)
|
||||
|
||||
// Skip processing of any `https` scheme URL that has not expressly been configured.
|
||||
if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) {
|
||||
if originURL, err = url.ParseRequestURI(string(origin)); err != nil || (originURL.Scheme != strProtoHTTPS && p.origins == nil) {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -188,7 +188,7 @@ func (s *OAuth2Session) SetSubject(subject string) {
|
|||
}
|
||||
|
||||
// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage.
|
||||
func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
|
||||
func (s *OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
|
||||
sessionData := s.Session
|
||||
|
||||
if session != nil {
|
||||
|
|
|
@ -23,7 +23,7 @@ type TOTPConfiguration struct {
|
|||
}
|
||||
|
||||
// URI shows the configuration in the URI representation.
|
||||
func (c TOTPConfiguration) URI() (uri string) {
|
||||
func (c *TOTPConfiguration) URI() (uri string) {
|
||||
v := url.Values{}
|
||||
v.Set("secret", string(c.Secret))
|
||||
v.Set("issuer", c.Issuer)
|
||||
|
@ -47,12 +47,12 @@ func (c *TOTPConfiguration) UpdateSignInInfo(now time.Time) {
|
|||
}
|
||||
|
||||
// Key returns the *otp.Key using TOTPConfiguration.URI with otp.NewKeyFromURL.
|
||||
func (c TOTPConfiguration) Key() (key *otp.Key, err error) {
|
||||
func (c *TOTPConfiguration) Key() (key *otp.Key, err error) {
|
||||
return otp.NewKeyFromURL(c.URI())
|
||||
}
|
||||
|
||||
// Image returns the image.Image of the TOTPConfiguration using the Image func from the return of TOTPConfiguration.Key.
|
||||
func (c TOTPConfiguration) Image(width, height int) (img image.Image, err error) {
|
||||
func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error) {
|
||||
var key *otp.Key
|
||||
|
||||
if key, err = c.Key(); err != nil {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/ory/fosite/token/jwt"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
|
@ -38,17 +38,17 @@ func NewKeyManager() (manager *KeyManager) {
|
|||
}
|
||||
|
||||
// Strategy returns the RS256JWTStrategy.
|
||||
func (m KeyManager) Strategy() (strategy *RS256JWTStrategy) {
|
||||
func (m *KeyManager) Strategy() (strategy *RS256JWTStrategy) {
|
||||
return m.strategy
|
||||
}
|
||||
|
||||
// GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types.
|
||||
func (m KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
|
||||
func (m *KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
|
||||
return m.keySet
|
||||
}
|
||||
|
||||
// GetActiveWebKey obtains the currently active jose.JSONWebKey.
|
||||
func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
||||
func (m *KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
||||
webKeys := m.keySet.Key(m.activeKeyID)
|
||||
if len(webKeys) == 1 {
|
||||
return &webKeys[0], nil
|
||||
|
@ -62,12 +62,12 @@ func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
|||
}
|
||||
|
||||
// GetActiveKeyID returns the key id of the currently active key.
|
||||
func (m KeyManager) GetActiveKeyID() (keyID string) {
|
||||
func (m *KeyManager) GetActiveKeyID() (keyID string) {
|
||||
return m.activeKeyID
|
||||
}
|
||||
|
||||
// GetActiveKey returns the rsa.PublicKey of the currently active key.
|
||||
func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
||||
func (m *KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
||||
if key, ok := m.keys[m.activeKeyID]; ok {
|
||||
return &key.PublicKey, nil
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
|||
}
|
||||
|
||||
// GetActivePrivateKey returns the rsa.PrivateKey of the currently active key.
|
||||
func (m KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
|
||||
func (m *KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
|
||||
if key, ok := m.keys[m.activeKeyID]; ok {
|
||||
return key, nil
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider s
|
|||
}
|
||||
|
||||
// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username.
|
||||
func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
||||
func (s *OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
||||
if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil {
|
||||
return nil, err
|
||||
} else if opaqueID == nil {
|
||||
|
@ -55,7 +55,7 @@ func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID,
|
|||
}
|
||||
|
||||
// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it.
|
||||
func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
|
||||
func (s *OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
|
||||
var opaqueID *model.UserOpaqueIdentifier
|
||||
|
||||
if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil {
|
||||
|
@ -66,7 +66,7 @@ func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username s
|
|||
}
|
||||
|
||||
// GetClientPolicy retrieves the policy from the client with the matching provided id.
|
||||
func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
|
||||
func (s *OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
|
||||
client, err := s.GetFullClient(id)
|
||||
if err != nil {
|
||||
return authorization.TwoFactor
|
||||
|
@ -76,7 +76,7 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve
|
|||
}
|
||||
|
||||
// GetFullClient returns a fosite.Client asserted as an Client matching the provided id.
|
||||
func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
|
||||
func (s *OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
|
||||
client, ok := s.clients[id]
|
||||
if !ok {
|
||||
return nil, fosite.ErrNotFound
|
||||
|
@ -86,7 +86,7 @@ func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error)
|
|||
}
|
||||
|
||||
// IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map.
|
||||
func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
|
||||
func (s *OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
|
||||
_, err := s.GetFullClient(id)
|
||||
|
||||
return err == nil
|
||||
|
|
|
@ -61,7 +61,7 @@ func (s *UserSession) SetTwoFactorWebauthn(now time.Time, userPresence, userVeri
|
|||
}
|
||||
|
||||
// AuthenticatedTime returns the unix timestamp this session authenticated successfully at the given level.
|
||||
func (s UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) {
|
||||
func (s *UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) {
|
||||
switch level {
|
||||
case authorization.OneFactor:
|
||||
return time.Unix(s.FirstFactorAuthnTimestamp, 0), nil
|
||||
|
|
|
@ -139,7 +139,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
|||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if err = p.schemaEncryptionCheckU2F(ctx); err != nil {
|
||||
if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
@ -210,9 +210,9 @@ func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
|
||||
func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) {
|
||||
var (
|
||||
device model.U2FDevice
|
||||
device model.WebauthnDevice
|
||||
row int
|
||||
invalid int
|
||||
total int
|
||||
|
@ -226,7 +226,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
|
|||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
|
||||
_ = rows.Close()
|
||||
|
||||
return fmt.Errorf("error selecting U2F devices: %w", err)
|
||||
return fmt.Errorf("error selecting Webauthn devices: %w", err)
|
||||
}
|
||||
|
||||
row = 0
|
||||
|
@ -237,7 +237,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
|
|||
|
||||
if err = rows.StructScan(&device); err != nil {
|
||||
_ = rows.Close()
|
||||
return fmt.Errorf("error scanning U2F device to struct: %w", err)
|
||||
return fmt.Errorf("error scanning Webauthn device to struct: %w", err)
|
||||
}
|
||||
|
||||
if _, err = p.decrypt(device.PublicKey); err != nil {
|
||||
|
@ -253,17 +253,17 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
|
|||
}
|
||||
|
||||
if invalid != 0 {
|
||||
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
|
||||
return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||
return utils.Encrypt(clearText, &p.key)
|
||||
}
|
||||
|
||||
func (p SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
||||
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
||||
return utils.Decrypt(cipherText, &p.key)
|
||||
}
|
||||
|
||||
|
|
|
@ -246,7 +246,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.Sc
|
|||
return p.schemaMigrateFinalize(ctx, migration)
|
||||
}
|
||||
|
||||
func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
|
||||
func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
|
||||
return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
|
||||
}
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ const (
|
|||
const (
|
||||
period = "."
|
||||
https = "https"
|
||||
wss = "wss"
|
||||
)
|
||||
|
||||
// X.509 consts.
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
|
||||
// describing why.
|
||||
func IsStringAbsURL(input string) (err error) {
|
||||
parsedURL, err := url.Parse(input)
|
||||
parsedURL, err := url.ParseRequestURI(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse '%s' as a URL", input)
|
||||
}
|
||||
|
|
|
@ -235,13 +235,13 @@ func TestStringSliceURLConversionFuncs(t *testing.T) {
|
|||
func TestIsURLInSlice(t *testing.T) {
|
||||
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
|
||||
|
||||
google, err := url.Parse("https://google.com")
|
||||
google, err := url.ParseRequestURI("https://google.com")
|
||||
assert.NoError(t, err)
|
||||
|
||||
microsoft, err := url.Parse("https://microsoft.com")
|
||||
microsoft, err := url.ParseRequestURI("https://microsoft.com")
|
||||
assert.NoError(t, err)
|
||||
|
||||
example, err := url.Parse("https://example.com")
|
||||
example, err := url.ParseRequestURI("https://example.com")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, IsURLInSlice(*google, urls))
|
||||
|
|
|
@ -30,12 +30,37 @@ func URLPathFullClean(u *url.URL) (output string) {
|
|||
}
|
||||
}
|
||||
|
||||
// URLDomainHasSuffix determines whether the uri has a suffix of the domain value.
|
||||
func URLDomainHasSuffix(uri url.URL, domain string) bool {
|
||||
if uri.Scheme != https {
|
||||
return false
|
||||
// IsURIStringSafeRedirection determines whether the URI is safe to be redirected to.
|
||||
func IsURIStringSafeRedirection(uri, protectedDomain string) (safe bool, err error) {
|
||||
var parsedURI *url.URL
|
||||
|
||||
if parsedURI, err = url.ParseRequestURI(uri); err != nil {
|
||||
return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err)
|
||||
}
|
||||
|
||||
return parsedURI != nil && IsURISafeRedirection(parsedURI, protectedDomain), nil
|
||||
}
|
||||
|
||||
// IsURISafeRedirection returns true if the URI passes the IsURISecure and HasURIDomainSuffix, i.e. if the scheme is
|
||||
// secure and the given URI has a hostname that is either exactly equal to the given domain or if it has a suffix of the
|
||||
// domain prefixed with a period.
|
||||
func IsURISafeRedirection(uri *url.URL, domain string) bool {
|
||||
return IsURISecure(uri) && HasURIDomainSuffix(uri, domain)
|
||||
}
|
||||
|
||||
// IsURISecure returns true if the URI has a secure schemes (https or wss).
|
||||
func IsURISecure(uri *url.URL) bool {
|
||||
switch uri.Scheme {
|
||||
case https, wss:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// HasURIDomainSuffix returns true if the URI hostname is equal to the domain or if it has a suffix of the domain
|
||||
// prefixed with a period.
|
||||
func HasURIDomainSuffix(uri *url.URL, domain string) bool {
|
||||
if uri.Hostname() == domain {
|
||||
return true
|
||||
}
|
||||
|
@ -46,14 +71,3 @@ func URLDomainHasSuffix(uri url.URL, domain string) bool {
|
|||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsRedirectionURISafe determines whether the URI is safe to be redirected to.
|
||||
func IsRedirectionURISafe(uri, protectedDomain string) (safe bool, err error) {
|
||||
var parsedURI *url.URL
|
||||
|
||||
if parsedURI, err = url.ParseRequestURI(uri); err != nil {
|
||||
return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err)
|
||||
}
|
||||
|
||||
return parsedURI != nil && URLDomainHasSuffix(*parsedURI, protectedDomain), nil
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ func TestURLPathFullClean(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
u, err := url.Parse(tc.have)
|
||||
u, err := url.ParseRequestURI(tc.have)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := URLPathFullClean(u)
|
||||
|
@ -41,7 +41,7 @@ func TestURLPathFullClean(t *testing.T) {
|
|||
|
||||
func isURLSafe(requestURI string, domain string) bool { //nolint:unparam
|
||||
u, _ := url.ParseRequestURI(requestURI)
|
||||
return URLDomainHasSuffix(*u, domain)
|
||||
return IsURISafeRedirection(u, domain)
|
||||
}
|
||||
|
||||
func TestIsRedirectionSafe_ShouldReturnTrueOnExactDomain(t *testing.T) {
|
||||
|
@ -62,22 +62,22 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIsRedirectionURISafe_CannotParseURI(t *testing.T) {
|
||||
_, err := IsRedirectionURISafe("http//invalid", "example.com")
|
||||
_, err := IsURIStringSafeRedirection("http//invalid", "example.com")
|
||||
assert.EqualError(t, err, "failed to parse URI 'http//invalid': parse \"http//invalid\": invalid URI for request")
|
||||
}
|
||||
|
||||
func TestIsRedirectionURISafe_InvalidRedirectionURI(t *testing.T) {
|
||||
valid, err := IsRedirectionURISafe("http://myurl.com/myresource", "example.com")
|
||||
valid, err := IsURIStringSafeRedirection("http://myurl.com/myresource", "example.com")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
}
|
||||
|
||||
func TestIsRedirectionURISafe_ValidRedirectionURI(t *testing.T) {
|
||||
valid, err := IsRedirectionURISafe("http://myurl.example.com/myresource", "example.com")
|
||||
valid, err := IsURIStringSafeRedirection("http://myurl.example.com/myresource", "example.com")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
|
||||
valid, err = IsRedirectionURISafe("http://example.com/myresource", "example.com")
|
||||
valid, err = IsURIStringSafeRedirection("http://example.com/myresource", "example.com")
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue