refactor(oidc): simplify hmac core strategy (#4711)

pull/4712/head
James Elliott 2023-01-07 10:28:53 +11:00 committed by GitHub
parent 680d502295
commit f223975e79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 20 deletions

View File

@ -40,11 +40,10 @@ func NewConfig(config *schema.OpenIDConnectConfiguration) *Config {
}, },
} }
prefix := "authelia_%s_"
c.Strategy.Core = &HMACCoreStrategy{ c.Strategy.Core = &HMACCoreStrategy{
Enigma: &hmac.HMACStrategy{Config: c}, Enigma: &hmac.HMACStrategy{Config: c},
Config: c, Config: c,
prefix: &prefix, prefix: tokenPrefixFmt,
} }
return c return c

View File

@ -106,6 +106,13 @@ const (
JWTHeaderKeyIdentifier = "kid" JWTHeaderKeyIdentifier = "kid"
) )
const (
tokenPrefixFmt = "authelia_%s_" //nolint:gosec
tokenPrefixPartAccessToken = "at"
tokenPrefixPartRefreshToken = "rt"
tokenPrefixPartAuthorizeCode = "ac"
)
// Paths. // Paths.
const ( const (
EndpointPathConsent = "/consent" EndpointPathConsent = "/consent"

View File

@ -19,7 +19,7 @@ type HMACCoreStrategy struct {
fosite.RefreshTokenLifespanProvider fosite.RefreshTokenLifespanProvider
fosite.AuthorizeCodeLifespanProvider fosite.AuthorizeCodeLifespanProvider
} }
prefix *string prefix string
} }
// AccessTokenSignature implements oauth2.AccessTokenStrategy. // AccessTokenSignature implements oauth2.AccessTokenStrategy.
@ -34,7 +34,7 @@ func (h *HMACCoreStrategy) GenerateAccessToken(ctx context.Context, _ fosite.Req
return "", "", err return "", "", err
} }
return h.setPrefix(token, "at"), sig, nil return h.setPrefix(token, tokenPrefixPartAccessToken), sig, nil
} }
// ValidateAccessToken implements oauth2.AccessTokenStrategy. // ValidateAccessToken implements oauth2.AccessTokenStrategy.
@ -48,7 +48,7 @@ func (h *HMACCoreStrategy) ValidateAccessToken(ctx context.Context, r fosite.Req
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", exp)) return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Access token expired at '%s'.", exp))
} }
return h.Enigma.Validate(ctx, h.trimPrefix(token, "at")) return h.Enigma.Validate(ctx, h.trimPrefix(token, tokenPrefixPartAccessToken))
} }
// RefreshTokenSignature implements oauth2.RefreshTokenStrategy. // RefreshTokenSignature implements oauth2.RefreshTokenStrategy.
@ -63,21 +63,22 @@ func (h *HMACCoreStrategy) GenerateRefreshToken(ctx context.Context, _ fosite.Re
return "", "", err return "", "", err
} }
return h.setPrefix(token, "rt"), sig, nil return h.setPrefix(token, tokenPrefixPartRefreshToken), sig, nil
} }
// ValidateRefreshToken implements oauth2.RefreshTokenStrategy. // ValidateRefreshToken implements oauth2.RefreshTokenStrategy.
func (h *HMACCoreStrategy) ValidateRefreshToken(ctx context.Context, r fosite.Requester, token string) (err error) { func (h *HMACCoreStrategy) ValidateRefreshToken(ctx context.Context, r fosite.Requester, token string) (err error) {
var exp = r.GetSession().GetExpiresAt(fosite.RefreshToken) var exp = r.GetSession().GetExpiresAt(fosite.RefreshToken)
if exp.IsZero() { if exp.IsZero() {
return h.Enigma.Validate(ctx, h.trimPrefix(token, "rt")) return h.Enigma.Validate(ctx, h.trimPrefix(token, tokenPrefixPartRefreshToken))
} }
if !exp.IsZero() && exp.Before(time.Now().UTC()) { if exp.Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Refresh token expired at '%s'.", exp)) return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Refresh token expired at '%s'.", exp))
} }
return h.Enigma.Validate(ctx, h.trimPrefix(token, "rt")) return h.Enigma.Validate(ctx, h.trimPrefix(token, tokenPrefixPartRefreshToken))
} }
// AuthorizeCodeSignature implements oauth2.AuthorizeCodeStrategy. // AuthorizeCodeSignature implements oauth2.AuthorizeCodeStrategy.
@ -92,12 +93,13 @@ func (h *HMACCoreStrategy) GenerateAuthorizeCode(ctx context.Context, _ fosite.R
return "", "", err return "", "", err
} }
return h.setPrefix(token, "ac"), sig, nil return h.setPrefix(token, tokenPrefixPartAuthorizeCode), sig, nil
} }
// ValidateAuthorizeCode implements oauth2.AuthorizeCodeStrategy. // ValidateAuthorizeCode implements oauth2.AuthorizeCodeStrategy.
func (h *HMACCoreStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.Requester, token string) (err error) { func (h *HMACCoreStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.Requester, token string) (err error) {
var exp = r.GetSession().GetExpiresAt(fosite.AuthorizeCode) var exp = r.GetSession().GetExpiresAt(fosite.AuthorizeCode)
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetAuthorizeCodeLifespan(ctx)).Before(time.Now().UTC()) { if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetAuthorizeCodeLifespan(ctx)).Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetAuthorizeCodeLifespan(ctx)))) return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetAuthorizeCodeLifespan(ctx))))
} }
@ -106,24 +108,21 @@ func (h *HMACCoreStrategy) ValidateAuthorizeCode(ctx context.Context, r fosite.R
return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", exp)) return errorsx.WithStack(fosite.ErrTokenExpired.WithHintf("Authorize code expired at '%s'.", exp))
} }
return h.Enigma.Validate(ctx, h.trimPrefix(token, "ac")) return h.Enigma.Validate(ctx, h.trimPrefix(token, tokenPrefixPartAuthorizeCode))
} }
func (h *HMACCoreStrategy) getPrefix(part string) string { func (h *HMACCoreStrategy) getPrefix(part string) string {
if h.prefix == nil { if len(h.prefix) == 0 {
prefix := "ory_%s_"
h.prefix = &prefix
} else if len(*h.prefix) == 0 {
return "" return ""
} }
return fmt.Sprintf(*h.prefix, part) return fmt.Sprintf(h.prefix, part)
}
func (h *HMACCoreStrategy) trimPrefix(token, part string) string {
return strings.TrimPrefix(token, h.getPrefix(part))
} }
func (h *HMACCoreStrategy) setPrefix(token, part string) string { func (h *HMACCoreStrategy) setPrefix(token, part string) string {
return h.getPrefix(part) + token return h.getPrefix(part) + token
} }
func (h *HMACCoreStrategy) trimPrefix(token, part string) string {
return strings.TrimPrefix(token, h.getPrefix(part))
}