refactor: clean up uri checking functions (#3943)

pull/3949/head
James Elliott 2022-09-03 11:51:02 +10:00 committed by GitHub
parent 02636966a8
commit 2325031052
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 96 additions and 75 deletions

View File

@ -19,7 +19,7 @@ func (acs *AccessControlSubjects) AddSubject(subjectRule string) {
} }
// IsMatch returns true if the ACL subjects match the subject properties. // IsMatch returns true if the ACL subjects match the subject properties.
func (acs AccessControlSubjects) IsMatch(subject Subject) (match bool) { func (acs *AccessControlSubjects) IsMatch(subject Subject) (match bool) {
for _, rule := range acs.Subjects { for _, rule := range acs.Subjects {
if !rule.IsMatch(subject) { if !rule.IsMatch(subject) {
return false return false

View File

@ -9,7 +9,7 @@ import (
) )
func TestShouldAppendQueryParamToURL(t *testing.T) { func TestShouldAppendQueryParamToURL(t *testing.T) {
targetURL, err := url.Parse("https://domain.example.com/api?type=none") targetURL, err := url.ParseRequestURI("https://domain.example.com/api?type=none")
require.NoError(t, err) require.NoError(t, err)
@ -22,7 +22,7 @@ func TestShouldAppendQueryParamToURL(t *testing.T) {
} }
func TestShouldCreateNewObjectFromRaw(t *testing.T) { func TestShouldCreateNewObjectFromRaw(t *testing.T) {
targetURL, err := url.Parse("https://domain.example.com/api") targetURL, err := url.ParseRequestURI("https://domain.example.com/api")
require.NoError(t, err) require.NoError(t, err)
@ -55,7 +55,7 @@ func TestShouldCleanURL(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.have, func(t *testing.T) { t.Run(tc.have, func(t *testing.T) {
have, err := url.Parse(tc.have + tc.havePath) have, err := url.ParseRequestURI(tc.have + tc.havePath)
require.NoError(t, err) require.NoError(t, err)
object := NewObject(have, tc.method) object := NewObject(have, tc.method)
@ -66,7 +66,7 @@ func TestShouldCleanURL(t *testing.T) {
assert.Equal(t, tc.expectedPathClean, object.Path) assert.Equal(t, tc.expectedPathClean, object.Path)
assert.Equal(t, tc.method, object.Method) assert.Equal(t, tc.method, object.Method)
have, err = url.Parse(tc.have) have, err = url.ParseRequestURI(tc.have)
require.NoError(t, err) require.NoError(t, err)
path, err := url.ParseRequestURI(tc.havePath) path, err := url.ParseRequestURI(tc.havePath)

View File

@ -208,7 +208,7 @@ func getSubjectAndObjectFromFlags(cmd *cobra.Command) (subject authorization.Sub
return subject, object, err return subject, object, err
} }
parsedURL, err := url.Parse(requestURL) parsedURL, err := url.ParseRequestURI(requestURL)
if err != nil { if err != nil {
return subject, object, err return subject, object, err
} }

View File

@ -99,7 +99,7 @@ func TestShouldRaiseErrorWithBadDefaultRedirectionURL(t *testing.T) {
require.Len(t, validator.Errors(), 1) require.Len(t, validator.Errors(), 1)
require.Len(t, validator.Warnings(), 1) require.Len(t, validator.Warnings(), 1)
assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: the url 'bad_default_redirection_url' is not absolute because it doesn't start with a scheme like 'ldap://' or 'ldaps://'") assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: could not parse 'bad_default_redirection_url' as a URL")
assert.EqualError(t, validator.Warnings()[0], "access control: no rules have been specified so the 'default_policy' of 'two_factor' is going to be applied to all requests") assert.EqualError(t, validator.Warnings()[0], "access control: no rules have been specified so the 'default_policy' of 'two_factor' is going to be applied to all requests")
} }

View File

@ -95,7 +95,7 @@ func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfigura
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) { func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
for _, client := range config.Clients { for _, client := range config.Clients {
for _, redirectURI := range client.RedirectURIs { for _, redirectURI := range client.RedirectURIs {
uri, err := url.Parse(redirectURI) uri, err := url.ParseRequestURI(redirectURI)
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" { if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
continue continue
} }
@ -116,6 +116,7 @@ func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration,
} }
} }
} }
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) { func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
invalidID, duplicateIDs := false, false invalidID, duplicateIDs := false, false

View File

@ -25,7 +25,7 @@ func CheckSafeRedirectionPOST(ctx *middlewares.AutheliaCtx) {
return return
} }
safe, err := utils.IsRedirectionURISafe(reqBody.URI, ctx.Configuration.Session.Domain) safe, err := utils.IsURIStringSafeRedirection(reqBody.URI, ctx.Configuration.Session.Domain)
if err != nil { if err != nil {
ctx.Error(fmt.Errorf("unable to determine if uri %s is safe to redirect to: %w", reqBody.URI, err), messageOperationFailed) ctx.Error(fmt.Errorf("unable to determine if uri %s is safe to redirect to: %w", reqBody.URI, err), messageOperationFailed)
return return

View File

@ -31,9 +31,9 @@ func LogoutPOST(ctx *middlewares.AutheliaCtx) {
ctx.Error(fmt.Errorf("unable to destroy session during logout: %s", err), messageOperationFailed) ctx.Error(fmt.Errorf("unable to destroy session during logout: %s", err), messageOperationFailed)
} }
redirectionURL, err := url.Parse(body.TargetURL) redirectionURL, err := url.ParseRequestURI(body.TargetURL)
if err == nil { if err == nil {
responseBody.SafeTargetURL = utils.URLDomainHasSuffix(*redirectionURL, ctx.Configuration.Session.Domain) responseBody.SafeTargetURL = utils.IsURISafeRedirection(redirectionURL, ctx.Configuration.Session.Domain)
} }
if body.TargetURL != "" { if body.TargetURL != "" {

View File

@ -29,7 +29,7 @@ func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string
err error err error
) )
if issuer, err = url.Parse(rootURI); err != nil { if issuer, err = url.ParseRequestURI(rootURI); err != nil {
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not safely determine the issuer.")) ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not safely determine the issuer."))
return nil, true return nil, true
@ -178,7 +178,7 @@ func handleOIDCAuthorizationConsentRedirect(ctx *middlewares.AutheliaCtx, issuer
var location *url.URL var location *url.URL
if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
location, _ = url.Parse(issuer.String()) location, _ = url.ParseRequestURI(issuer.String())
location.Path = path.Join(location.Path, "/consent") location.Path = path.Join(location.Path, "/consent")
query := location.Query() query := location.Query()
@ -229,9 +229,9 @@ func verifyOIDCUserAuthorizedForConsent(ctx *middlewares.AutheliaCtx, client *oi
} }
func getOIDCAuthorizationRedirectURL(issuer *url.URL, requester fosite.AuthorizeRequester) (redirectURL *url.URL) { func getOIDCAuthorizationRedirectURL(issuer *url.URL, requester fosite.AuthorizeRequester) (redirectURL *url.URL) {
redirectURL, _ = url.Parse(issuer.String()) redirectURL, _ = url.ParseRequestURI(issuer.String())
authorizationURL, _ := url.Parse(issuer.String()) authorizationURL, _ := url.ParseRequestURI(issuer.String())
authorizationURL.Path = path.Join(authorizationURL.Path, oidc.AuthorizationPath) authorizationURL.Path = path.Join(authorizationURL.Path, oidc.AuthorizationPath)
authorizationURL.RawQuery = requester.GetRequestForm().Encode() authorizationURL.RawQuery = requester.GetRequestForm().Encode()

View File

@ -104,7 +104,7 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI, requestMethod st
return return
} }
if !utils.URLDomainHasSuffix(*targetURL, ctx.Configuration.Session.Domain) { if !utils.IsURISafeRedirection(targetURL, ctx.Configuration.Session.Domain) {
ctx.Logger.Debugf("Redirection URL %s is not safe", targetURI) ctx.Logger.Debugf("Redirection URL %s is not safe", targetURI)
if !ctx.Providers.Authorizer.IsSecondFactorEnabled() && ctx.Configuration.DefaultRedirectionURL != "" { if !ctx.Providers.Authorizer.IsSecondFactorEnabled() && ctx.Configuration.DefaultRedirectionURL != "" {
@ -147,7 +147,7 @@ func Handle2FAResponse(ctx *middlewares.AutheliaCtx, targetURI string) {
var safe bool var safe bool
if safe, err = utils.IsRedirectionURISafe(targetURI, ctx.Configuration.Session.Domain); err != nil { if safe, err = utils.IsURIStringSafeRedirection(targetURI, ctx.Configuration.Session.Domain); err != nil {
ctx.Error(fmt.Errorf("unable to check target URL: %s", err), messageMFAValidationFailed) ctx.Error(fmt.Errorf("unable to check target URL: %s", err), messageMFAValidationFailed)
return return
@ -176,7 +176,7 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba
referer := ctx.Request.Header.Referer() referer := ctx.Request.Header.Referer()
if referer != nil { if referer != nil {
refererURL, err := url.Parse(string(referer)) refererURL, err := url.ParseRequestURI(string(referer))
if err == nil { if err == nil {
requestURI = refererURL.Query().Get("rd") requestURI = refererURL.Query().Get("rd")
requestMethod = refererURL.Query().Get("rm") requestMethod = refererURL.Query().Get("rm")

View File

@ -195,7 +195,7 @@ func (ctx *AutheliaCtx) ExternalRootURL() (string, error) {
externalRootURL := fmt.Sprintf("%s://%s", protocol, host) externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
if base := ctx.BasePath(); base != "" { if base := ctx.BasePath(); base != "" {
externalBaseURL, err := url.Parse(externalRootURL) externalBaseURL, err := url.ParseRequestURI(externalRootURL)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -317,14 +317,14 @@ func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
} }
// IsXHR returns true if the request is a XMLHttpRequest. // IsXHR returns true if the request is a XMLHttpRequest.
func (ctx AutheliaCtx) IsXHR() (xhr bool) { func (ctx *AutheliaCtx) IsXHR() (xhr bool) {
requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith) requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith)
return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR) return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR)
} }
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type. // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { func (ctx *AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",") accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",")
for i, accept := range accepts { for i, accept := range accepts {

View File

@ -62,9 +62,14 @@ var (
headerValueCohort = []byte("interest-cohort=()") headerValueCohort = []byte("interest-cohort=()")
) )
const (
strProtoHTTPS = "https"
strProtoHTTP = "http"
)
var ( var (
protoHTTPS = []byte("https") protoHTTPS = []byte(strProtoHTTPS)
protoHTTP = []byte("http") protoHTTP = []byte(strProtoHTTP)
// UserValueKeyBaseURL is the User Value key where we store the Base URL. // UserValueKeyBaseURL is the User Value key where we store the Base URL.
UserValueKeyBaseURL = []byte("base_url") UserValueKeyBaseURL = []byte("base_url")

View File

@ -265,7 +265,7 @@ func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
origin := ctx.Request.Header.PeekBytes(headerOrigin) origin := ctx.Request.Header.PeekBytes(headerOrigin)
// Skip processing of any `https` scheme URL that has not expressly been configured. // Skip processing of any `https` scheme URL that has not expressly been configured.
if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) { if originURL, err = url.ParseRequestURI(string(origin)); err != nil || (originURL.Scheme != strProtoHTTPS && p.origins == nil) {
return return
} }

View File

@ -188,7 +188,7 @@ func (s *OAuth2Session) SetSubject(subject string) {
} }
// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage. // ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage.
func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) { func (s *OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
sessionData := s.Session sessionData := s.Session
if session != nil { if session != nil {

View File

@ -23,7 +23,7 @@ type TOTPConfiguration struct {
} }
// URI shows the configuration in the URI representation. // URI shows the configuration in the URI representation.
func (c TOTPConfiguration) URI() (uri string) { func (c *TOTPConfiguration) URI() (uri string) {
v := url.Values{} v := url.Values{}
v.Set("secret", string(c.Secret)) v.Set("secret", string(c.Secret))
v.Set("issuer", c.Issuer) v.Set("issuer", c.Issuer)
@ -47,12 +47,12 @@ func (c *TOTPConfiguration) UpdateSignInInfo(now time.Time) {
} }
// Key returns the *otp.Key using TOTPConfiguration.URI with otp.NewKeyFromURL. // Key returns the *otp.Key using TOTPConfiguration.URI with otp.NewKeyFromURL.
func (c TOTPConfiguration) Key() (key *otp.Key, err error) { func (c *TOTPConfiguration) Key() (key *otp.Key, err error) {
return otp.NewKeyFromURL(c.URI()) return otp.NewKeyFromURL(c.URI())
} }
// Image returns the image.Image of the TOTPConfiguration using the Image func from the return of TOTPConfiguration.Key. // Image returns the image.Image of the TOTPConfiguration using the Image func from the return of TOTPConfiguration.Key.
func (c TOTPConfiguration) Image(width, height int) (img image.Image, err error) { func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error) {
var key *otp.Key var key *otp.Key
if key, err = c.Key(); err != nil { if key, err = c.Key(); err != nil {

View File

@ -9,7 +9,7 @@ import (
"strings" "strings"
"github.com/ory/fosite/token/jwt" "github.com/ory/fosite/token/jwt"
"gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
@ -38,17 +38,17 @@ func NewKeyManager() (manager *KeyManager) {
} }
// Strategy returns the RS256JWTStrategy. // Strategy returns the RS256JWTStrategy.
func (m KeyManager) Strategy() (strategy *RS256JWTStrategy) { func (m *KeyManager) Strategy() (strategy *RS256JWTStrategy) {
return m.strategy return m.strategy
} }
// GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types. // GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types.
func (m KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) { func (m *KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
return m.keySet return m.keySet
} }
// GetActiveWebKey obtains the currently active jose.JSONWebKey. // GetActiveWebKey obtains the currently active jose.JSONWebKey.
func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) { func (m *KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
webKeys := m.keySet.Key(m.activeKeyID) webKeys := m.keySet.Key(m.activeKeyID)
if len(webKeys) == 1 { if len(webKeys) == 1 {
return &webKeys[0], nil return &webKeys[0], nil
@ -62,12 +62,12 @@ func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
} }
// GetActiveKeyID returns the key id of the currently active key. // GetActiveKeyID returns the key id of the currently active key.
func (m KeyManager) GetActiveKeyID() (keyID string) { func (m *KeyManager) GetActiveKeyID() (keyID string) {
return m.activeKeyID return m.activeKeyID
} }
// GetActiveKey returns the rsa.PublicKey of the currently active key. // GetActiveKey returns the rsa.PublicKey of the currently active key.
func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) { func (m *KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
if key, ok := m.keys[m.activeKeyID]; ok { if key, ok := m.keys[m.activeKeyID]; ok {
return &key.PublicKey, nil return &key.PublicKey, nil
} }
@ -76,7 +76,7 @@ func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
} }
// GetActivePrivateKey returns the rsa.PrivateKey of the currently active key. // GetActivePrivateKey returns the rsa.PrivateKey of the currently active key.
func (m KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) { func (m *KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
if key, ok := m.keys[m.activeKeyID]; ok { if key, ok := m.keys[m.activeKeyID]; ok {
return key, nil return key, nil
} }

View File

@ -38,7 +38,7 @@ func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider s
} }
// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username. // GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username.
func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { func (s *OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil { if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil {
return nil, err return nil, err
} else if opaqueID == nil { } else if opaqueID == nil {
@ -55,7 +55,7 @@ func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID,
} }
// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it. // GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it.
func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) { func (s *OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
var opaqueID *model.UserOpaqueIdentifier var opaqueID *model.UserOpaqueIdentifier
if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil { if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil {
@ -66,7 +66,7 @@ func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username s
} }
// GetClientPolicy retrieves the policy from the client with the matching provided id. // GetClientPolicy retrieves the policy from the client with the matching provided id.
func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) { func (s *OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
client, err := s.GetFullClient(id) client, err := s.GetFullClient(id)
if err != nil { if err != nil {
return authorization.TwoFactor return authorization.TwoFactor
@ -76,7 +76,7 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve
} }
// GetFullClient returns a fosite.Client asserted as an Client matching the provided id. // GetFullClient returns a fosite.Client asserted as an Client matching the provided id.
func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) { func (s *OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
client, ok := s.clients[id] client, ok := s.clients[id]
if !ok { if !ok {
return nil, fosite.ErrNotFound return nil, fosite.ErrNotFound
@ -86,7 +86,7 @@ func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error)
} }
// IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map. // IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map.
func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) { func (s *OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
_, err := s.GetFullClient(id) _, err := s.GetFullClient(id)
return err == nil return err == nil

View File

@ -61,7 +61,7 @@ func (s *UserSession) SetTwoFactorWebauthn(now time.Time, userPresence, userVeri
} }
// AuthenticatedTime returns the unix timestamp this session authenticated successfully at the given level. // AuthenticatedTime returns the unix timestamp this session authenticated successfully at the given level.
func (s UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) { func (s *UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) {
switch level { switch level {
case authorization.OneFactor: case authorization.OneFactor:
return time.Unix(s.FirstFactorAuthnTimestamp, 0), nil return time.Unix(s.FirstFactorAuthnTimestamp, 0), nil

View File

@ -139,7 +139,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
errs = append(errs, err) errs = append(errs, err)
} }
if err = p.schemaEncryptionCheckU2F(ctx); err != nil { if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
} }
@ -210,9 +210,9 @@ func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error)
return nil return nil
} }
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) { func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) {
var ( var (
device model.U2FDevice device model.WebauthnDevice
row int row int
invalid int invalid int
total int total int
@ -226,7 +226,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil { if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
_ = rows.Close() _ = rows.Close()
return fmt.Errorf("error selecting U2F devices: %w", err) return fmt.Errorf("error selecting Webauthn devices: %w", err)
} }
row = 0 row = 0
@ -237,7 +237,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
if err = rows.StructScan(&device); err != nil { if err = rows.StructScan(&device); err != nil {
_ = rows.Close() _ = rows.Close()
return fmt.Errorf("error scanning U2F device to struct: %w", err) return fmt.Errorf("error scanning Webauthn device to struct: %w", err)
} }
if _, err = p.decrypt(device.PublicKey); err != nil { if _, err = p.decrypt(device.PublicKey); err != nil {
@ -253,17 +253,17 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
} }
if invalid != 0 { if invalid != 0 {
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total) return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total)
} }
return nil return nil
} }
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) { func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
return utils.Encrypt(clearText, &p.key) return utils.Encrypt(clearText, &p.key)
} }
func (p SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) { func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
return utils.Decrypt(cipherText, &p.key) return utils.Decrypt(cipherText, &p.key)
} }

View File

@ -246,7 +246,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.Sc
return p.schemaMigrateFinalize(ctx, migration) return p.schemaMigrateFinalize(ctx, migration)
} }
func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) { func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After()) return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
} }

View File

@ -34,6 +34,7 @@ const (
const ( const (
period = "." period = "."
https = "https" https = "https"
wss = "wss"
) )
// X.509 consts. // X.509 consts.

View File

@ -16,7 +16,7 @@ import (
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error // IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
// describing why. // describing why.
func IsStringAbsURL(input string) (err error) { func IsStringAbsURL(input string) (err error) {
parsedURL, err := url.Parse(input) parsedURL, err := url.ParseRequestURI(input)
if err != nil { if err != nil {
return fmt.Errorf("could not parse '%s' as a URL", input) return fmt.Errorf("could not parse '%s' as a URL", input)
} }

View File

@ -235,13 +235,13 @@ func TestStringSliceURLConversionFuncs(t *testing.T) {
func TestIsURLInSlice(t *testing.T) { func TestIsURLInSlice(t *testing.T) {
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"}) urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
google, err := url.Parse("https://google.com") google, err := url.ParseRequestURI("https://google.com")
assert.NoError(t, err) assert.NoError(t, err)
microsoft, err := url.Parse("https://microsoft.com") microsoft, err := url.ParseRequestURI("https://microsoft.com")
assert.NoError(t, err) assert.NoError(t, err)
example, err := url.Parse("https://example.com") example, err := url.ParseRequestURI("https://example.com")
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, IsURLInSlice(*google, urls)) assert.True(t, IsURLInSlice(*google, urls))

View File

@ -30,12 +30,37 @@ func URLPathFullClean(u *url.URL) (output string) {
} }
} }
// URLDomainHasSuffix determines whether the uri has a suffix of the domain value. // IsURIStringSafeRedirection determines whether the URI is safe to be redirected to.
func URLDomainHasSuffix(uri url.URL, domain string) bool { func IsURIStringSafeRedirection(uri, protectedDomain string) (safe bool, err error) {
if uri.Scheme != https { var parsedURI *url.URL
return false
if parsedURI, err = url.ParseRequestURI(uri); err != nil {
return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err)
} }
return parsedURI != nil && IsURISafeRedirection(parsedURI, protectedDomain), nil
}
// IsURISafeRedirection returns true if the URI passes the IsURISecure and HasURIDomainSuffix, i.e. if the scheme is
// secure and the given URI has a hostname that is either exactly equal to the given domain or if it has a suffix of the
// domain prefixed with a period.
func IsURISafeRedirection(uri *url.URL, domain string) bool {
return IsURISecure(uri) && HasURIDomainSuffix(uri, domain)
}
// IsURISecure returns true if the URI has a secure schemes (https or wss).
func IsURISecure(uri *url.URL) bool {
switch uri.Scheme {
case https, wss:
return true
default:
return false
}
}
// HasURIDomainSuffix returns true if the URI hostname is equal to the domain or if it has a suffix of the domain
// prefixed with a period.
func HasURIDomainSuffix(uri *url.URL, domain string) bool {
if uri.Hostname() == domain { if uri.Hostname() == domain {
return true return true
} }
@ -46,14 +71,3 @@ func URLDomainHasSuffix(uri url.URL, domain string) bool {
return false return false
} }
// IsRedirectionURISafe determines whether the URI is safe to be redirected to.
func IsRedirectionURISafe(uri, protectedDomain string) (safe bool, err error) {
var parsedURI *url.URL
if parsedURI, err = url.ParseRequestURI(uri); err != nil {
return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err)
}
return parsedURI != nil && URLDomainHasSuffix(*parsedURI, protectedDomain), nil
}

View File

@ -29,7 +29,7 @@ func TestURLPathFullClean(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
u, err := url.Parse(tc.have) u, err := url.ParseRequestURI(tc.have)
require.NoError(t, err) require.NoError(t, err)
actual := URLPathFullClean(u) actual := URLPathFullClean(u)
@ -41,7 +41,7 @@ func TestURLPathFullClean(t *testing.T) {
func isURLSafe(requestURI string, domain string) bool { //nolint:unparam func isURLSafe(requestURI string, domain string) bool { //nolint:unparam
u, _ := url.ParseRequestURI(requestURI) u, _ := url.ParseRequestURI(requestURI)
return URLDomainHasSuffix(*u, domain) return IsURISafeRedirection(u, domain)
} }
func TestIsRedirectionSafe_ShouldReturnTrueOnExactDomain(t *testing.T) { func TestIsRedirectionSafe_ShouldReturnTrueOnExactDomain(t *testing.T) {
@ -62,22 +62,22 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) {
} }
func TestIsRedirectionURISafe_CannotParseURI(t *testing.T) { func TestIsRedirectionURISafe_CannotParseURI(t *testing.T) {
_, err := IsRedirectionURISafe("http//invalid", "example.com") _, err := IsURIStringSafeRedirection("http//invalid", "example.com")
assert.EqualError(t, err, "failed to parse URI 'http//invalid': parse \"http//invalid\": invalid URI for request") assert.EqualError(t, err, "failed to parse URI 'http//invalid': parse \"http//invalid\": invalid URI for request")
} }
func TestIsRedirectionURISafe_InvalidRedirectionURI(t *testing.T) { func TestIsRedirectionURISafe_InvalidRedirectionURI(t *testing.T) {
valid, err := IsRedirectionURISafe("http://myurl.com/myresource", "example.com") valid, err := IsURIStringSafeRedirection("http://myurl.com/myresource", "example.com")
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, valid) assert.False(t, valid)
} }
func TestIsRedirectionURISafe_ValidRedirectionURI(t *testing.T) { func TestIsRedirectionURISafe_ValidRedirectionURI(t *testing.T) {
valid, err := IsRedirectionURISafe("http://myurl.example.com/myresource", "example.com") valid, err := IsURIStringSafeRedirection("http://myurl.example.com/myresource", "example.com")
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, valid) assert.False(t, valid)
valid, err = IsRedirectionURISafe("http://example.com/myresource", "example.com") valid, err = IsURIStringSafeRedirection("http://example.com/myresource", "example.com")
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, valid) assert.False(t, valid)
} }