refactor: clean up uri checking functions (#3943)
parent
02636966a8
commit
2325031052
|
@ -19,7 +19,7 @@ func (acs *AccessControlSubjects) AddSubject(subjectRule string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsMatch returns true if the ACL subjects match the subject properties.
|
// IsMatch returns true if the ACL subjects match the subject properties.
|
||||||
func (acs AccessControlSubjects) IsMatch(subject Subject) (match bool) {
|
func (acs *AccessControlSubjects) IsMatch(subject Subject) (match bool) {
|
||||||
for _, rule := range acs.Subjects {
|
for _, rule := range acs.Subjects {
|
||||||
if !rule.IsMatch(subject) {
|
if !rule.IsMatch(subject) {
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldAppendQueryParamToURL(t *testing.T) {
|
func TestShouldAppendQueryParamToURL(t *testing.T) {
|
||||||
targetURL, err := url.Parse("https://domain.example.com/api?type=none")
|
targetURL, err := url.ParseRequestURI("https://domain.example.com/api?type=none")
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ func TestShouldAppendQueryParamToURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldCreateNewObjectFromRaw(t *testing.T) {
|
func TestShouldCreateNewObjectFromRaw(t *testing.T) {
|
||||||
targetURL, err := url.Parse("https://domain.example.com/api")
|
targetURL, err := url.ParseRequestURI("https://domain.example.com/api")
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ func TestShouldCleanURL(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.have, func(t *testing.T) {
|
t.Run(tc.have, func(t *testing.T) {
|
||||||
have, err := url.Parse(tc.have + tc.havePath)
|
have, err := url.ParseRequestURI(tc.have + tc.havePath)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
object := NewObject(have, tc.method)
|
object := NewObject(have, tc.method)
|
||||||
|
@ -66,7 +66,7 @@ func TestShouldCleanURL(t *testing.T) {
|
||||||
assert.Equal(t, tc.expectedPathClean, object.Path)
|
assert.Equal(t, tc.expectedPathClean, object.Path)
|
||||||
assert.Equal(t, tc.method, object.Method)
|
assert.Equal(t, tc.method, object.Method)
|
||||||
|
|
||||||
have, err = url.Parse(tc.have)
|
have, err = url.ParseRequestURI(tc.have)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
path, err := url.ParseRequestURI(tc.havePath)
|
path, err := url.ParseRequestURI(tc.havePath)
|
||||||
|
|
|
@ -208,7 +208,7 @@ func getSubjectAndObjectFromFlags(cmd *cobra.Command) (subject authorization.Sub
|
||||||
return subject, object, err
|
return subject, object, err
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, err := url.Parse(requestURL)
|
parsedURL, err := url.ParseRequestURI(requestURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return subject, object, err
|
return subject, object, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,7 @@ func TestShouldRaiseErrorWithBadDefaultRedirectionURL(t *testing.T) {
|
||||||
require.Len(t, validator.Errors(), 1)
|
require.Len(t, validator.Errors(), 1)
|
||||||
require.Len(t, validator.Warnings(), 1)
|
require.Len(t, validator.Warnings(), 1)
|
||||||
|
|
||||||
assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: the url 'bad_default_redirection_url' is not absolute because it doesn't start with a scheme like 'ldap://' or 'ldaps://'")
|
assert.EqualError(t, validator.Errors()[0], "option 'default_redirection_url' is invalid: could not parse 'bad_default_redirection_url' as a URL")
|
||||||
assert.EqualError(t, validator.Warnings()[0], "access control: no rules have been specified so the 'default_policy' of 'two_factor' is going to be applied to all requests")
|
assert.EqualError(t, validator.Warnings()[0], "access control: no rules have been specified so the 'default_policy' of 'two_factor' is going to be applied to all requests")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -95,7 +95,7 @@ func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfigura
|
||||||
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
|
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
|
||||||
for _, client := range config.Clients {
|
for _, client := range config.Clients {
|
||||||
for _, redirectURI := range client.RedirectURIs {
|
for _, redirectURI := range client.RedirectURIs {
|
||||||
uri, err := url.Parse(redirectURI)
|
uri, err := url.ParseRequestURI(redirectURI)
|
||||||
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
|
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -116,6 +116,7 @@ func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||||
invalidID, duplicateIDs := false, false
|
invalidID, duplicateIDs := false, false
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ func CheckSafeRedirectionPOST(ctx *middlewares.AutheliaCtx) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
safe, err := utils.IsRedirectionURISafe(reqBody.URI, ctx.Configuration.Session.Domain)
|
safe, err := utils.IsURIStringSafeRedirection(reqBody.URI, ctx.Configuration.Session.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error(fmt.Errorf("unable to determine if uri %s is safe to redirect to: %w", reqBody.URI, err), messageOperationFailed)
|
ctx.Error(fmt.Errorf("unable to determine if uri %s is safe to redirect to: %w", reqBody.URI, err), messageOperationFailed)
|
||||||
return
|
return
|
||||||
|
|
|
@ -31,9 +31,9 @@ func LogoutPOST(ctx *middlewares.AutheliaCtx) {
|
||||||
ctx.Error(fmt.Errorf("unable to destroy session during logout: %s", err), messageOperationFailed)
|
ctx.Error(fmt.Errorf("unable to destroy session during logout: %s", err), messageOperationFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectionURL, err := url.Parse(body.TargetURL)
|
redirectionURL, err := url.ParseRequestURI(body.TargetURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
responseBody.SafeTargetURL = utils.URLDomainHasSuffix(*redirectionURL, ctx.Configuration.Session.Domain)
|
responseBody.SafeTargetURL = utils.IsURISafeRedirection(redirectionURL, ctx.Configuration.Session.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
if body.TargetURL != "" {
|
if body.TargetURL != "" {
|
||||||
|
|
|
@ -29,7 +29,7 @@ func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
if issuer, err = url.Parse(rootURI); err != nil {
|
if issuer, err = url.ParseRequestURI(rootURI); err != nil {
|
||||||
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not safely determine the issuer."))
|
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not safely determine the issuer."))
|
||||||
|
|
||||||
return nil, true
|
return nil, true
|
||||||
|
@ -178,7 +178,7 @@ func handleOIDCAuthorizationConsentRedirect(ctx *middlewares.AutheliaCtx, issuer
|
||||||
var location *url.URL
|
var location *url.URL
|
||||||
|
|
||||||
if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
|
if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
|
||||||
location, _ = url.Parse(issuer.String())
|
location, _ = url.ParseRequestURI(issuer.String())
|
||||||
location.Path = path.Join(location.Path, "/consent")
|
location.Path = path.Join(location.Path, "/consent")
|
||||||
|
|
||||||
query := location.Query()
|
query := location.Query()
|
||||||
|
@ -229,9 +229,9 @@ func verifyOIDCUserAuthorizedForConsent(ctx *middlewares.AutheliaCtx, client *oi
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOIDCAuthorizationRedirectURL(issuer *url.URL, requester fosite.AuthorizeRequester) (redirectURL *url.URL) {
|
func getOIDCAuthorizationRedirectURL(issuer *url.URL, requester fosite.AuthorizeRequester) (redirectURL *url.URL) {
|
||||||
redirectURL, _ = url.Parse(issuer.String())
|
redirectURL, _ = url.ParseRequestURI(issuer.String())
|
||||||
|
|
||||||
authorizationURL, _ := url.Parse(issuer.String())
|
authorizationURL, _ := url.ParseRequestURI(issuer.String())
|
||||||
|
|
||||||
authorizationURL.Path = path.Join(authorizationURL.Path, oidc.AuthorizationPath)
|
authorizationURL.Path = path.Join(authorizationURL.Path, oidc.AuthorizationPath)
|
||||||
authorizationURL.RawQuery = requester.GetRequestForm().Encode()
|
authorizationURL.RawQuery = requester.GetRequestForm().Encode()
|
||||||
|
|
|
@ -104,7 +104,7 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI, requestMethod st
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !utils.URLDomainHasSuffix(*targetURL, ctx.Configuration.Session.Domain) {
|
if !utils.IsURISafeRedirection(targetURL, ctx.Configuration.Session.Domain) {
|
||||||
ctx.Logger.Debugf("Redirection URL %s is not safe", targetURI)
|
ctx.Logger.Debugf("Redirection URL %s is not safe", targetURI)
|
||||||
|
|
||||||
if !ctx.Providers.Authorizer.IsSecondFactorEnabled() && ctx.Configuration.DefaultRedirectionURL != "" {
|
if !ctx.Providers.Authorizer.IsSecondFactorEnabled() && ctx.Configuration.DefaultRedirectionURL != "" {
|
||||||
|
@ -147,7 +147,7 @@ func Handle2FAResponse(ctx *middlewares.AutheliaCtx, targetURI string) {
|
||||||
|
|
||||||
var safe bool
|
var safe bool
|
||||||
|
|
||||||
if safe, err = utils.IsRedirectionURISafe(targetURI, ctx.Configuration.Session.Domain); err != nil {
|
if safe, err = utils.IsURIStringSafeRedirection(targetURI, ctx.Configuration.Session.Domain); err != nil {
|
||||||
ctx.Error(fmt.Errorf("unable to check target URL: %s", err), messageMFAValidationFailed)
|
ctx.Error(fmt.Errorf("unable to check target URL: %s", err), messageMFAValidationFailed)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -176,7 +176,7 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba
|
||||||
|
|
||||||
referer := ctx.Request.Header.Referer()
|
referer := ctx.Request.Header.Referer()
|
||||||
if referer != nil {
|
if referer != nil {
|
||||||
refererURL, err := url.Parse(string(referer))
|
refererURL, err := url.ParseRequestURI(string(referer))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
requestURI = refererURL.Query().Get("rd")
|
requestURI = refererURL.Query().Get("rd")
|
||||||
requestMethod = refererURL.Query().Get("rm")
|
requestMethod = refererURL.Query().Get("rm")
|
||||||
|
|
|
@ -195,7 +195,7 @@ func (ctx *AutheliaCtx) ExternalRootURL() (string, error) {
|
||||||
externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
|
externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
|
||||||
|
|
||||||
if base := ctx.BasePath(); base != "" {
|
if base := ctx.BasePath(); base != "" {
|
||||||
externalBaseURL, err := url.Parse(externalRootURL)
|
externalBaseURL, err := url.ParseRequestURI(externalRootURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -317,14 +317,14 @@ func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsXHR returns true if the request is a XMLHttpRequest.
|
// IsXHR returns true if the request is a XMLHttpRequest.
|
||||||
func (ctx AutheliaCtx) IsXHR() (xhr bool) {
|
func (ctx *AutheliaCtx) IsXHR() (xhr bool) {
|
||||||
requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith)
|
requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith)
|
||||||
|
|
||||||
return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR)
|
return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
|
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
|
||||||
func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
func (ctx *AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
||||||
accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",")
|
accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",")
|
||||||
|
|
||||||
for i, accept := range accepts {
|
for i, accept := range accepts {
|
||||||
|
|
|
@ -62,9 +62,14 @@ var (
|
||||||
headerValueCohort = []byte("interest-cohort=()")
|
headerValueCohort = []byte("interest-cohort=()")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
strProtoHTTPS = "https"
|
||||||
|
strProtoHTTP = "http"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
protoHTTPS = []byte("https")
|
protoHTTPS = []byte(strProtoHTTPS)
|
||||||
protoHTTP = []byte("http")
|
protoHTTP = []byte(strProtoHTTP)
|
||||||
|
|
||||||
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
||||||
UserValueKeyBaseURL = []byte("base_url")
|
UserValueKeyBaseURL = []byte("base_url")
|
||||||
|
|
|
@ -265,7 +265,7 @@ func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
|
||||||
origin := ctx.Request.Header.PeekBytes(headerOrigin)
|
origin := ctx.Request.Header.PeekBytes(headerOrigin)
|
||||||
|
|
||||||
// Skip processing of any `https` scheme URL that has not expressly been configured.
|
// Skip processing of any `https` scheme URL that has not expressly been configured.
|
||||||
if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) {
|
if originURL, err = url.ParseRequestURI(string(origin)); err != nil || (originURL.Scheme != strProtoHTTPS && p.origins == nil) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -188,7 +188,7 @@ func (s *OAuth2Session) SetSubject(subject string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage.
|
// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage.
|
||||||
func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
|
func (s *OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
|
||||||
sessionData := s.Session
|
sessionData := s.Session
|
||||||
|
|
||||||
if session != nil {
|
if session != nil {
|
||||||
|
|
|
@ -23,7 +23,7 @@ type TOTPConfiguration struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// URI shows the configuration in the URI representation.
|
// URI shows the configuration in the URI representation.
|
||||||
func (c TOTPConfiguration) URI() (uri string) {
|
func (c *TOTPConfiguration) URI() (uri string) {
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Set("secret", string(c.Secret))
|
v.Set("secret", string(c.Secret))
|
||||||
v.Set("issuer", c.Issuer)
|
v.Set("issuer", c.Issuer)
|
||||||
|
@ -47,12 +47,12 @@ func (c *TOTPConfiguration) UpdateSignInInfo(now time.Time) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Key returns the *otp.Key using TOTPConfiguration.URI with otp.NewKeyFromURL.
|
// Key returns the *otp.Key using TOTPConfiguration.URI with otp.NewKeyFromURL.
|
||||||
func (c TOTPConfiguration) Key() (key *otp.Key, err error) {
|
func (c *TOTPConfiguration) Key() (key *otp.Key, err error) {
|
||||||
return otp.NewKeyFromURL(c.URI())
|
return otp.NewKeyFromURL(c.URI())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Image returns the image.Image of the TOTPConfiguration using the Image func from the return of TOTPConfiguration.Key.
|
// Image returns the image.Image of the TOTPConfiguration using the Image func from the return of TOTPConfiguration.Key.
|
||||||
func (c TOTPConfiguration) Image(width, height int) (img image.Image, err error) {
|
func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error) {
|
||||||
var key *otp.Key
|
var key *otp.Key
|
||||||
|
|
||||||
if key, err = c.Key(); err != nil {
|
if key, err = c.Key(); err != nil {
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ory/fosite/token/jwt"
|
"github.com/ory/fosite/token/jwt"
|
||||||
"gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
|
@ -38,17 +38,17 @@ func NewKeyManager() (manager *KeyManager) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strategy returns the RS256JWTStrategy.
|
// Strategy returns the RS256JWTStrategy.
|
||||||
func (m KeyManager) Strategy() (strategy *RS256JWTStrategy) {
|
func (m *KeyManager) Strategy() (strategy *RS256JWTStrategy) {
|
||||||
return m.strategy
|
return m.strategy
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types.
|
// GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types.
|
||||||
func (m KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
|
func (m *KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
|
||||||
return m.keySet
|
return m.keySet
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveWebKey obtains the currently active jose.JSONWebKey.
|
// GetActiveWebKey obtains the currently active jose.JSONWebKey.
|
||||||
func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
func (m *KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
||||||
webKeys := m.keySet.Key(m.activeKeyID)
|
webKeys := m.keySet.Key(m.activeKeyID)
|
||||||
if len(webKeys) == 1 {
|
if len(webKeys) == 1 {
|
||||||
return &webKeys[0], nil
|
return &webKeys[0], nil
|
||||||
|
@ -62,12 +62,12 @@ func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveKeyID returns the key id of the currently active key.
|
// GetActiveKeyID returns the key id of the currently active key.
|
||||||
func (m KeyManager) GetActiveKeyID() (keyID string) {
|
func (m *KeyManager) GetActiveKeyID() (keyID string) {
|
||||||
return m.activeKeyID
|
return m.activeKeyID
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveKey returns the rsa.PublicKey of the currently active key.
|
// GetActiveKey returns the rsa.PublicKey of the currently active key.
|
||||||
func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
func (m *KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
||||||
if key, ok := m.keys[m.activeKeyID]; ok {
|
if key, ok := m.keys[m.activeKeyID]; ok {
|
||||||
return &key.PublicKey, nil
|
return &key.PublicKey, nil
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,7 @@ func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActivePrivateKey returns the rsa.PrivateKey of the currently active key.
|
// GetActivePrivateKey returns the rsa.PrivateKey of the currently active key.
|
||||||
func (m KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
|
func (m *KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
|
||||||
if key, ok := m.keys[m.activeKeyID]; ok {
|
if key, ok := m.keys[m.activeKeyID]; ok {
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider s
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username.
|
// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username.
|
||||||
func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
func (s *OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
||||||
if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil {
|
if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if opaqueID == nil {
|
} else if opaqueID == nil {
|
||||||
|
@ -55,7 +55,7 @@ func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it.
|
// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it.
|
||||||
func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
|
func (s *OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
|
||||||
var opaqueID *model.UserOpaqueIdentifier
|
var opaqueID *model.UserOpaqueIdentifier
|
||||||
|
|
||||||
if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil {
|
if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil {
|
||||||
|
@ -66,7 +66,7 @@ func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username s
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPolicy retrieves the policy from the client with the matching provided id.
|
// GetClientPolicy retrieves the policy from the client with the matching provided id.
|
||||||
func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
|
func (s *OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
|
||||||
client, err := s.GetFullClient(id)
|
client, err := s.GetFullClient(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return authorization.TwoFactor
|
return authorization.TwoFactor
|
||||||
|
@ -76,7 +76,7 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFullClient returns a fosite.Client asserted as an Client matching the provided id.
|
// GetFullClient returns a fosite.Client asserted as an Client matching the provided id.
|
||||||
func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
|
func (s *OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
|
||||||
client, ok := s.clients[id]
|
client, ok := s.clients[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fosite.ErrNotFound
|
return nil, fosite.ErrNotFound
|
||||||
|
@ -86,7 +86,7 @@ func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map.
|
// IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map.
|
||||||
func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
|
func (s *OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
|
||||||
_, err := s.GetFullClient(id)
|
_, err := s.GetFullClient(id)
|
||||||
|
|
||||||
return err == nil
|
return err == nil
|
||||||
|
|
|
@ -61,7 +61,7 @@ func (s *UserSession) SetTwoFactorWebauthn(now time.Time, userPresence, userVeri
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticatedTime returns the unix timestamp this session authenticated successfully at the given level.
|
// AuthenticatedTime returns the unix timestamp this session authenticated successfully at the given level.
|
||||||
func (s UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) {
|
func (s *UserSession) AuthenticatedTime(level authorization.Level) (authenticatedTime time.Time, err error) {
|
||||||
switch level {
|
switch level {
|
||||||
case authorization.OneFactor:
|
case authorization.OneFactor:
|
||||||
return time.Unix(s.FirstFactorAuthnTimestamp, 0), nil
|
return time.Unix(s.FirstFactorAuthnTimestamp, 0), nil
|
||||||
|
|
|
@ -139,7 +139,7 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.schemaEncryptionCheckU2F(ctx); err != nil {
|
if err = p.schemaEncryptionCheckWebauthn(ctx); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -210,9 +210,9 @@ func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
|
func (p *SQLProvider) schemaEncryptionCheckWebauthn(ctx context.Context) (err error) {
|
||||||
var (
|
var (
|
||||||
device model.U2FDevice
|
device model.WebauthnDevice
|
||||||
row int
|
row int
|
||||||
invalid int
|
invalid int
|
||||||
total int
|
total int
|
||||||
|
@ -226,7 +226,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
|
||||||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
|
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
|
||||||
_ = rows.Close()
|
_ = rows.Close()
|
||||||
|
|
||||||
return fmt.Errorf("error selecting U2F devices: %w", err)
|
return fmt.Errorf("error selecting Webauthn devices: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
row = 0
|
row = 0
|
||||||
|
@ -237,7 +237,7 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
|
||||||
|
|
||||||
if err = rows.StructScan(&device); err != nil {
|
if err = rows.StructScan(&device); err != nil {
|
||||||
_ = rows.Close()
|
_ = rows.Close()
|
||||||
return fmt.Errorf("error scanning U2F device to struct: %w", err)
|
return fmt.Errorf("error scanning Webauthn device to struct: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = p.decrypt(device.PublicKey); err != nil {
|
if _, err = p.decrypt(device.PublicKey); err != nil {
|
||||||
|
@ -253,17 +253,17 @@ func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if invalid != 0 {
|
if invalid != 0 {
|
||||||
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
|
return fmt.Errorf("%d of %d total Webauthn devices were invalid", invalid, total)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||||
return utils.Encrypt(clearText, &p.key)
|
return utils.Encrypt(clearText, &p.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
|
||||||
return utils.Decrypt(cipherText, &p.key)
|
return utils.Decrypt(cipherText, &p.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -246,7 +246,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration model.Sc
|
||||||
return p.schemaMigrateFinalize(ctx, migration)
|
return p.schemaMigrateFinalize(ctx, migration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
|
func (p *SQLProvider) schemaMigrateFinalize(ctx context.Context, migration model.SchemaMigration) (err error) {
|
||||||
return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
|
return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ const (
|
||||||
const (
|
const (
|
||||||
period = "."
|
period = "."
|
||||||
https = "https"
|
https = "https"
|
||||||
|
wss = "wss"
|
||||||
)
|
)
|
||||||
|
|
||||||
// X.509 consts.
|
// X.509 consts.
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
|
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
|
||||||
// describing why.
|
// describing why.
|
||||||
func IsStringAbsURL(input string) (err error) {
|
func IsStringAbsURL(input string) (err error) {
|
||||||
parsedURL, err := url.Parse(input)
|
parsedURL, err := url.ParseRequestURI(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not parse '%s' as a URL", input)
|
return fmt.Errorf("could not parse '%s' as a URL", input)
|
||||||
}
|
}
|
||||||
|
|
|
@ -235,13 +235,13 @@ func TestStringSliceURLConversionFuncs(t *testing.T) {
|
||||||
func TestIsURLInSlice(t *testing.T) {
|
func TestIsURLInSlice(t *testing.T) {
|
||||||
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
|
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
|
||||||
|
|
||||||
google, err := url.Parse("https://google.com")
|
google, err := url.ParseRequestURI("https://google.com")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
microsoft, err := url.Parse("https://microsoft.com")
|
microsoft, err := url.ParseRequestURI("https://microsoft.com")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
example, err := url.Parse("https://example.com")
|
example, err := url.ParseRequestURI("https://example.com")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, IsURLInSlice(*google, urls))
|
assert.True(t, IsURLInSlice(*google, urls))
|
||||||
|
|
|
@ -30,12 +30,37 @@ func URLPathFullClean(u *url.URL) (output string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// URLDomainHasSuffix determines whether the uri has a suffix of the domain value.
|
// IsURIStringSafeRedirection determines whether the URI is safe to be redirected to.
|
||||||
func URLDomainHasSuffix(uri url.URL, domain string) bool {
|
func IsURIStringSafeRedirection(uri, protectedDomain string) (safe bool, err error) {
|
||||||
if uri.Scheme != https {
|
var parsedURI *url.URL
|
||||||
return false
|
|
||||||
|
if parsedURI, err = url.ParseRequestURI(uri); err != nil {
|
||||||
|
return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return parsedURI != nil && IsURISafeRedirection(parsedURI, protectedDomain), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsURISafeRedirection returns true if the URI passes the IsURISecure and HasURIDomainSuffix, i.e. if the scheme is
|
||||||
|
// secure and the given URI has a hostname that is either exactly equal to the given domain or if it has a suffix of the
|
||||||
|
// domain prefixed with a period.
|
||||||
|
func IsURISafeRedirection(uri *url.URL, domain string) bool {
|
||||||
|
return IsURISecure(uri) && HasURIDomainSuffix(uri, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsURISecure returns true if the URI has a secure schemes (https or wss).
|
||||||
|
func IsURISecure(uri *url.URL) bool {
|
||||||
|
switch uri.Scheme {
|
||||||
|
case https, wss:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasURIDomainSuffix returns true if the URI hostname is equal to the domain or if it has a suffix of the domain
|
||||||
|
// prefixed with a period.
|
||||||
|
func HasURIDomainSuffix(uri *url.URL, domain string) bool {
|
||||||
if uri.Hostname() == domain {
|
if uri.Hostname() == domain {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -46,14 +71,3 @@ func URLDomainHasSuffix(uri url.URL, domain string) bool {
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsRedirectionURISafe determines whether the URI is safe to be redirected to.
|
|
||||||
func IsRedirectionURISafe(uri, protectedDomain string) (safe bool, err error) {
|
|
||||||
var parsedURI *url.URL
|
|
||||||
|
|
||||||
if parsedURI, err = url.ParseRequestURI(uri); err != nil {
|
|
||||||
return false, fmt.Errorf("failed to parse URI '%s': %w", uri, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsedURI != nil && URLDomainHasSuffix(*parsedURI, protectedDomain), nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ func TestURLPathFullClean(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
u, err := url.Parse(tc.have)
|
u, err := url.ParseRequestURI(tc.have)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual := URLPathFullClean(u)
|
actual := URLPathFullClean(u)
|
||||||
|
@ -41,7 +41,7 @@ func TestURLPathFullClean(t *testing.T) {
|
||||||
|
|
||||||
func isURLSafe(requestURI string, domain string) bool { //nolint:unparam
|
func isURLSafe(requestURI string, domain string) bool { //nolint:unparam
|
||||||
u, _ := url.ParseRequestURI(requestURI)
|
u, _ := url.ParseRequestURI(requestURI)
|
||||||
return URLDomainHasSuffix(*u, domain)
|
return IsURISafeRedirection(u, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectionSafe_ShouldReturnTrueOnExactDomain(t *testing.T) {
|
func TestIsRedirectionSafe_ShouldReturnTrueOnExactDomain(t *testing.T) {
|
||||||
|
@ -62,22 +62,22 @@ func TestIsRedirectionSafe_ShouldReturnFalseOnBadDomain(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectionURISafe_CannotParseURI(t *testing.T) {
|
func TestIsRedirectionURISafe_CannotParseURI(t *testing.T) {
|
||||||
_, err := IsRedirectionURISafe("http//invalid", "example.com")
|
_, err := IsURIStringSafeRedirection("http//invalid", "example.com")
|
||||||
assert.EqualError(t, err, "failed to parse URI 'http//invalid': parse \"http//invalid\": invalid URI for request")
|
assert.EqualError(t, err, "failed to parse URI 'http//invalid': parse \"http//invalid\": invalid URI for request")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectionURISafe_InvalidRedirectionURI(t *testing.T) {
|
func TestIsRedirectionURISafe_InvalidRedirectionURI(t *testing.T) {
|
||||||
valid, err := IsRedirectionURISafe("http://myurl.com/myresource", "example.com")
|
valid, err := IsURIStringSafeRedirection("http://myurl.com/myresource", "example.com")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, valid)
|
assert.False(t, valid)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectionURISafe_ValidRedirectionURI(t *testing.T) {
|
func TestIsRedirectionURISafe_ValidRedirectionURI(t *testing.T) {
|
||||||
valid, err := IsRedirectionURISafe("http://myurl.example.com/myresource", "example.com")
|
valid, err := IsURIStringSafeRedirection("http://myurl.example.com/myresource", "example.com")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, valid)
|
assert.False(t, valid)
|
||||||
|
|
||||||
valid, err = IsRedirectionURISafe("http://example.com/myresource", "example.com")
|
valid, err = IsURIStringSafeRedirection("http://example.com/myresource", "example.com")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.False(t, valid)
|
assert.False(t, valid)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue