484 lines
13 KiB
Go
484 lines
13 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
fjwt "github.com/ory/fosite/token/jwt"
|
|
"github.com/ory/x/errorsx"
|
|
|
|
"gopkg.in/square/go-jose.v2"
|
|
|
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
|
)
|
|
|
|
// NewKeyManager news up a KeyManager.
|
|
func NewKeyManager(config *schema.OpenIDConnectConfiguration) (manager *KeyManager) {
|
|
manager = &KeyManager{
|
|
kids: map[string]*JWK{},
|
|
algs: map[string]*JWK{},
|
|
}
|
|
|
|
for _, sjwk := range config.IssuerPrivateKeys {
|
|
jwk := NewJWK(sjwk)
|
|
|
|
manager.kids[sjwk.KeyID] = jwk
|
|
manager.algs[jwk.alg.Alg()] = jwk
|
|
|
|
if jwk.kid == config.Discovery.DefaultKeyID {
|
|
manager.kid = jwk.kid
|
|
}
|
|
}
|
|
|
|
return manager
|
|
}
|
|
|
|
// The KeyManager type handles JWKs and signing operations.
|
|
type KeyManager struct {
|
|
kid string
|
|
kids map[string]*JWK
|
|
algs map[string]*JWK
|
|
}
|
|
|
|
// GetKeyID returns the default key id.
|
|
func (m *KeyManager) GetKeyID(ctx context.Context) string {
|
|
return m.kid
|
|
}
|
|
|
|
// GetKeyIDFromAlgStrict returns the key id given an alg or an error if it doesn't exist.
|
|
func (m *KeyManager) GetKeyIDFromAlgStrict(ctx context.Context, alg string) (kid string, err error) {
|
|
if jwks, ok := m.algs[alg]; ok {
|
|
return jwks.kid, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("alg not found")
|
|
}
|
|
|
|
// GetKeyIDFromAlg returns the key id given an alg or the default if it doesn't exist.
|
|
func (m *KeyManager) GetKeyIDFromAlg(ctx context.Context, alg string) string {
|
|
if jwks, ok := m.algs[alg]; ok {
|
|
return jwks.kid
|
|
}
|
|
|
|
return m.kid
|
|
}
|
|
|
|
// GetByAlg returns the JWK given an alg or nil if it doesn't exist.
|
|
func (m *KeyManager) GetByAlg(ctx context.Context, alg string) *JWK {
|
|
if jwk, ok := m.algs[alg]; ok {
|
|
return jwk
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetByKID returns the JWK given an key id or nil if it doesn't exist. If given a blank string it returns the default.
|
|
func (m *KeyManager) GetByKID(ctx context.Context, kid string) *JWK {
|
|
if kid == "" {
|
|
return m.kids[m.kid]
|
|
}
|
|
|
|
if jwk, ok := m.kids[kid]; ok {
|
|
return jwk
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetByHeader returns the JWK a JWT header with the appropriate kid value or returns an error.
|
|
func (m *KeyManager) GetByHeader(ctx context.Context, header fjwt.Mapper) (jwk *JWK, err error) {
|
|
var (
|
|
kid string
|
|
ok bool
|
|
)
|
|
|
|
if header == nil {
|
|
return nil, fmt.Errorf("jwt header was nil")
|
|
}
|
|
|
|
if kid, ok = header.Get(JWTHeaderKeyIdentifier).(string); !ok {
|
|
return nil, fmt.Errorf("jwt header did not have a kid")
|
|
}
|
|
|
|
if jwk, ok = m.kids[kid]; !ok {
|
|
return nil, fmt.Errorf("jwt header '%s' with value '%s' does not match a managed jwk", JWTHeaderKeyIdentifier, kid)
|
|
}
|
|
|
|
return jwk, nil
|
|
}
|
|
|
|
// GetByTokenString does an invalidated decode of a token to get the header, then calls GetByHeader.
|
|
func (m *KeyManager) GetByTokenString(ctx context.Context, tokenString string) (jwk *JWK, err error) {
|
|
var (
|
|
token *jwt.Token
|
|
)
|
|
|
|
if token, _, err = jwt.NewParser().ParseUnverified(tokenString, jwt.MapClaims{}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return m.GetByHeader(ctx, &fjwt.Headers{Extra: token.Header})
|
|
}
|
|
|
|
// Set returns the *jose.JSONWebKeySet.
|
|
func (m *KeyManager) Set(ctx context.Context) *jose.JSONWebKeySet {
|
|
keys := make([]jose.JSONWebKey, 0, len(m.kids))
|
|
|
|
for _, jwk := range m.kids {
|
|
keys = append(keys, jwk.JWK())
|
|
}
|
|
|
|
sort.Sort(SortedJSONWebKey(keys))
|
|
|
|
return &jose.JSONWebKeySet{
|
|
Keys: keys,
|
|
}
|
|
}
|
|
|
|
// Generate implements the fosite jwt.Signer interface and automatically maps the underlying keys based on the JWK Header kid.
|
|
func (m *KeyManager) Generate(ctx context.Context, claims fjwt.MapClaims, header fjwt.Mapper) (tokenString string, sig string, err error) {
|
|
var jwk *JWK
|
|
|
|
if jwk, err = m.GetByHeader(ctx, header); err != nil {
|
|
return "", "", fmt.Errorf("error getting jwk from header: %w", err)
|
|
}
|
|
|
|
return jwk.Strategy().Generate(ctx, claims, header)
|
|
}
|
|
|
|
// Validate implements the fosite jwt.Signer interface and automatically maps the underlying keys based on the JWK Header kid.
|
|
func (m *KeyManager) Validate(ctx context.Context, tokenString string) (sig string, err error) {
|
|
var jwk *JWK
|
|
|
|
if jwk, err = m.GetByTokenString(ctx, tokenString); err != nil {
|
|
return "", fmt.Errorf("error getting jwk from token string: %w", err)
|
|
}
|
|
|
|
return jwk.Strategy().Validate(ctx, tokenString)
|
|
}
|
|
|
|
// Hash implements the fosite jwt.Signer interface.
|
|
func (m *KeyManager) Hash(ctx context.Context, in []byte) (sum []byte, err error) {
|
|
return m.GetByKID(ctx, "").Strategy().Hash(ctx, in)
|
|
}
|
|
|
|
// Decode implements the fosite jwt.Signer interface and automatically maps the underlying keys based on the JWK Header kid.
|
|
func (m *KeyManager) Decode(ctx context.Context, tokenString string) (token *fjwt.Token, err error) {
|
|
var jwk *JWK
|
|
|
|
if jwk, err = m.GetByTokenString(ctx, tokenString); err != nil {
|
|
return nil, fmt.Errorf("error getting jwk from token string: %w", err)
|
|
}
|
|
|
|
return jwk.Strategy().Decode(ctx, tokenString)
|
|
}
|
|
|
|
// GetSignature implements the fosite jwt.Signer interface.
|
|
func (m *KeyManager) GetSignature(ctx context.Context, tokenString string) (sig string, err error) {
|
|
return getTokenSignature(tokenString)
|
|
}
|
|
|
|
// GetSigningMethodLength implements the fosite jwt.Signer interface.
|
|
func (m *KeyManager) GetSigningMethodLength(ctx context.Context) (size int) {
|
|
return m.GetByKID(ctx, "").Strategy().GetSigningMethodLength(ctx)
|
|
}
|
|
|
|
// NewPublicJSONWebKeySetFromSchemaJWK creates a *jose.JSONWebKeySet from a slice of schema.JWK.
|
|
func NewPublicJSONWebKeySetFromSchemaJWK(sjwks []schema.JWK) (jwks *jose.JSONWebKeySet) {
|
|
n := len(sjwks)
|
|
|
|
if n == 0 {
|
|
return nil
|
|
}
|
|
|
|
keys := make([]jose.JSONWebKey, n)
|
|
|
|
for i := 0; i < n; i++ {
|
|
jwk := jose.JSONWebKey{
|
|
KeyID: sjwks[i].KeyID,
|
|
Algorithm: sjwks[i].Algorithm,
|
|
Use: sjwks[i].Use,
|
|
Certificates: sjwks[i].CertificateChain.Certificates(),
|
|
}
|
|
|
|
switch key := sjwks[i].Key.(type) {
|
|
case *rsa.PublicKey:
|
|
jwk.Key = key
|
|
case rsa.PublicKey:
|
|
jwk.Key = &key
|
|
case *rsa.PrivateKey:
|
|
jwk.Key = key.PublicKey
|
|
case *ecdsa.PublicKey:
|
|
jwk.Key = key
|
|
case ecdsa.PublicKey:
|
|
jwk.Key = &key
|
|
case *ecdsa.PrivateKey:
|
|
jwk.Key = key.PublicKey
|
|
}
|
|
|
|
keys[i] = jwk
|
|
}
|
|
|
|
return &jose.JSONWebKeySet{
|
|
Keys: keys,
|
|
}
|
|
}
|
|
|
|
// NewJWK creates a *JWK f rom a schema.JWK.
|
|
func NewJWK(s schema.JWK) (jwk *JWK) {
|
|
jwk = &JWK{
|
|
kid: s.KeyID,
|
|
use: s.Use,
|
|
alg: jwt.GetSigningMethod(s.Algorithm),
|
|
key: s.Key.(schema.CryptographicPrivateKey),
|
|
|
|
chain: s.CertificateChain,
|
|
thumbprint: s.CertificateChain.Thumbprint(crypto.SHA256),
|
|
thumbprintsha1: s.CertificateChain.Thumbprint(crypto.SHA1),
|
|
}
|
|
|
|
switch jwk.alg {
|
|
case jwt.SigningMethodRS256, jwt.SigningMethodPS256, jwt.SigningMethodES256:
|
|
jwk.hash = crypto.SHA256
|
|
case jwt.SigningMethodRS384, jwt.SigningMethodPS384, jwt.SigningMethodES384:
|
|
jwk.hash = crypto.SHA384
|
|
case jwt.SigningMethodRS512, jwt.SigningMethodPS512, jwt.SigningMethodES512:
|
|
jwk.hash = crypto.SHA512
|
|
default:
|
|
jwk.hash = crypto.SHA256
|
|
}
|
|
|
|
return jwk
|
|
}
|
|
|
|
// JWK is a representation layer over the *jose.JSONWebKey for convenience.
|
|
type JWK struct {
|
|
kid string
|
|
use string
|
|
alg jwt.SigningMethod
|
|
hash crypto.Hash
|
|
|
|
key schema.CryptographicPrivateKey
|
|
chain schema.X509CertificateChain
|
|
thumbprintsha1 []byte
|
|
thumbprint []byte
|
|
}
|
|
|
|
// GetSigningMethod returns the jwt.SigningMethod for this *JWK.
|
|
func (j *JWK) GetSigningMethod() jwt.SigningMethod {
|
|
return j.alg
|
|
}
|
|
|
|
// GetPrivateKey returns the Private Key for this *JWK.
|
|
func (j *JWK) GetPrivateKey(ctx context.Context) (any, error) {
|
|
return j.PrivateJWK(), nil
|
|
}
|
|
|
|
// KeyID returns the Key ID for this *JWK.
|
|
func (j *JWK) KeyID() string {
|
|
return j.kid
|
|
}
|
|
|
|
// DirectJWK directly returns the *JWK as a jose.JSONWebKey with the private key if appropriate.
|
|
func (j *JWK) DirectJWK() (jwk jose.JSONWebKey) {
|
|
return jose.JSONWebKey{
|
|
Key: j.key,
|
|
KeyID: j.kid,
|
|
Algorithm: j.alg.Alg(),
|
|
Use: j.use,
|
|
Certificates: j.chain.Certificates(),
|
|
CertificateThumbprintSHA1: j.thumbprintsha1,
|
|
CertificateThumbprintSHA256: j.thumbprint,
|
|
}
|
|
}
|
|
|
|
// PrivateJWK directly returns the *JWK as a *jose.JSONWebKey with the private key if appropriate.
|
|
func (j *JWK) PrivateJWK() (jwk *jose.JSONWebKey) {
|
|
value := j.DirectJWK()
|
|
|
|
return &value
|
|
}
|
|
|
|
// JWK directly returns the *JWK as a jose.JSONWebKey specifically without the private key.
|
|
func (j *JWK) JWK() (jwk jose.JSONWebKey) {
|
|
if jwk = j.DirectJWK(); jwk.IsPublic() {
|
|
return jwk
|
|
}
|
|
|
|
return jwk.Public()
|
|
}
|
|
|
|
// Strategy returns the fosite jwt.Signer.
|
|
func (j *JWK) Strategy() (strategy fjwt.Signer) {
|
|
return &Signer{
|
|
hash: j.hash,
|
|
alg: j.alg,
|
|
GetPrivateKey: j.GetPrivateKey,
|
|
}
|
|
}
|
|
|
|
// Signer is responsible for generating and validating JWT challenges.
|
|
type Signer struct {
|
|
hash crypto.Hash
|
|
alg jwt.SigningMethod
|
|
|
|
GetPrivateKey fjwt.GetPrivateKeyFunc
|
|
}
|
|
|
|
// GetPublicKey returns the PublicKey for this Signer.
|
|
func (j *Signer) GetPublicKey(ctx context.Context) (key crypto.PublicKey, err error) {
|
|
var k any
|
|
|
|
if k, err = j.GetPrivateKey(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch t := k.(type) {
|
|
case *jose.JSONWebKey:
|
|
return t.Public().Key, nil
|
|
case jose.OpaqueSigner:
|
|
return t.Public().Key, nil
|
|
case schema.CryptographicPrivateKey:
|
|
return t.Public(), nil
|
|
default:
|
|
return nil, errors.New("invalid private key type")
|
|
}
|
|
}
|
|
|
|
// Generate generates a new authorize code or returns an error. set secret.
|
|
func (j *Signer) Generate(ctx context.Context, claims fjwt.MapClaims, header fjwt.Mapper) (tokenString string, sig string, err error) {
|
|
var key any
|
|
|
|
if key, err = j.GetPrivateKey(ctx); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
switch t := key.(type) {
|
|
case *jose.JSONWebKey:
|
|
return generateToken(claims, header, j.alg, t.Key)
|
|
case jose.JSONWebKey:
|
|
return generateToken(claims, header, j.alg, t.Key)
|
|
case *rsa.PrivateKey, *ecdsa.PrivateKey:
|
|
return generateToken(claims, header, j.alg, t)
|
|
case jose.OpaqueSigner:
|
|
switch tt := t.Public().Key.(type) {
|
|
case *rsa.PrivateKey, *ecdsa.PrivateKey:
|
|
return generateToken(claims, header, j.alg, t)
|
|
default:
|
|
return "", "", fmt.Errorf("unsupported private / public key pairs: %T, %T", t, tt)
|
|
}
|
|
default:
|
|
return "", "", fmt.Errorf("unsupported private key type: %T", t)
|
|
}
|
|
}
|
|
|
|
// Validate validates a token and returns its signature or an error if the token is not valid.
|
|
func (j *Signer) Validate(ctx context.Context, tokenString string) (sig string, err error) {
|
|
var (
|
|
key crypto.PublicKey
|
|
)
|
|
|
|
if key, err = j.GetPublicKey(ctx); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return validateToken(tokenString, key)
|
|
}
|
|
|
|
// Decode will decode a JWT token.
|
|
func (j *Signer) Decode(ctx context.Context, tokenString string) (token *fjwt.Token, err error) {
|
|
var (
|
|
key crypto.PublicKey
|
|
)
|
|
|
|
if key, err = j.GetPublicKey(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return decodeToken(tokenString, key)
|
|
}
|
|
|
|
// GetSignature will return the signature of a token.
|
|
func (j *Signer) GetSignature(ctx context.Context, tokenString string) (sig string, err error) {
|
|
return getTokenSignature(tokenString)
|
|
}
|
|
|
|
// Hash will return a given hash based on the byte input or an error upon fail.
|
|
func (j *Signer) Hash(ctx context.Context, in []byte) (sum []byte, err error) {
|
|
hash := j.hash.New()
|
|
|
|
if _, err = hash.Write(in); err != nil {
|
|
return []byte{}, errorsx.WithStack(err)
|
|
}
|
|
|
|
return hash.Sum([]byte{}), nil
|
|
}
|
|
|
|
// GetSigningMethodLength will return the length of the signing method.
|
|
func (j *Signer) GetSigningMethodLength(ctx context.Context) (size int) {
|
|
return j.hash.Size()
|
|
}
|
|
|
|
func generateToken(claims fjwt.MapClaims, header fjwt.Mapper, signingMethod jwt.SigningMethod, key any) (rawToken string, sig string, err error) {
|
|
if header == nil || claims == nil {
|
|
return "", "", errors.New("either claims or header is nil")
|
|
}
|
|
|
|
token := jwt.NewWithClaims(signingMethod, claims)
|
|
|
|
token.Header = assign(token.Header, header.ToMap())
|
|
|
|
if rawToken, err = token.SignedString(key); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
if sig, err = getTokenSignature(rawToken); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
return rawToken, sig, nil
|
|
}
|
|
|
|
func decodeToken(tokenString string, key any) (token *fjwt.Token, err error) {
|
|
return fjwt.ParseWithClaims(tokenString, fjwt.MapClaims{}, func(*fjwt.Token) (any, error) {
|
|
return key, nil
|
|
})
|
|
}
|
|
|
|
func validateToken(tokenString string, key any) (sig string, err error) {
|
|
if _, err = decodeToken(tokenString, key); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return getTokenSignature(tokenString)
|
|
}
|
|
|
|
func getTokenSignature(tokenString string) (sig string, err error) {
|
|
parts := strings.Split(tokenString, ".")
|
|
|
|
if len(parts) != 3 {
|
|
return "", errors.New("header, body and signature must all be set")
|
|
}
|
|
|
|
return parts[2], nil
|
|
}
|
|
|
|
func assign(a, b map[string]any) map[string]any {
|
|
for k, w := range b {
|
|
if _, ok := a[k]; ok {
|
|
continue
|
|
}
|
|
|
|
a[k] = w
|
|
}
|
|
|
|
return a
|
|
}
|