authelia/internal/oidc/keys.go

401 lines
9.7 KiB
Go
Raw Normal View History

package oidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"errors"
"fmt"
"sort"
"strings"
"github.com/golang-jwt/jwt/v4"
fjwt "github.com/ory/fosite/token/jwt"
"github.com/ory/x/errorsx"
"gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// NewKeyManager builds a KeyManager from the issuer JWKS in config.
//
// Every configured key is indexed by its key ID and by its algorithm name. When
// several keys share an algorithm the last one configured wins the algorithm
// slot. The key whose ID equals config.Discovery.DefaultKeyID becomes the
// default key.
func NewKeyManager(config *schema.OpenIDConnectConfiguration) (manager *KeyManager) {
	count := len(config.IssuerJWKS)

	manager = &KeyManager{
		kids: make(map[string]*JWK, count),
		algs: make(map[string]*JWK, count),
	}

	for _, entry := range config.IssuerJWKS {
		key := NewJWK(entry)

		manager.kids[entry.KeyID] = key
		manager.algs[key.alg.Alg()] = key

		if config.Discovery.DefaultKeyID == key.kid {
			manager.kid = key.kid
		}
	}

	return manager
}
// The KeyManager type handles JWKs and signing operations.
//
// It indexes the managed keys by key ID and by signing algorithm name, and
// remembers the key ID of the default key.
type KeyManager struct {
// kid is the key ID of the default key. NewKeyManager sets it when a configured
// key's ID equals config.Discovery.DefaultKeyID; otherwise it stays empty.
kid string
// kids maps a key ID to its JWK.
kids map[string]*JWK
// algs maps a signing algorithm name (the result of alg.Alg()) to a JWK. Only
// one JWK is held per algorithm.
algs map[string]*JWK
}
// GetKIDFromAlgStrict returns the key ID of the JWK registered for alg, or an
// error when no JWK is registered for that algorithm.
func (m *KeyManager) GetKIDFromAlgStrict(ctx context.Context, alg string) (kid string, err error) {
	jwk, found := m.algs[alg]
	if !found {
		return "", errors.New("alg not found")
	}

	return jwk.kid, nil
}
// GetKIDFromAlg returns the key ID of the JWK registered for alg, falling back
// to the default key ID when no JWK is registered for that algorithm.
func (m *KeyManager) GetKIDFromAlg(ctx context.Context, alg string) string {
	jwk, found := m.algs[alg]
	if !found {
		return m.kid
	}

	return jwk.kid
}
// GetByAlg returns the JWK registered for alg, or nil when there is none.
func (m *KeyManager) GetByAlg(ctx context.Context, alg string) *JWK {
	// A missing map entry yields the zero value, which is nil for *JWK.
	return m.algs[alg]
}
// GetByKID returns the JWK with the given key ID, or nil when it is not managed.
// An empty kid selects the default key.
func (m *KeyManager) GetByKID(ctx context.Context, kid string) *JWK {
	if kid == "" {
		kid = m.kid
	}

	// A missing map entry yields the zero value, which is nil for *JWK.
	return m.kids[kid]
}
// GetByHeader resolves the managed JWK named by the key identifier in a JWT
// header. It returns an error when the header is nil, when the key identifier
// is absent or not a string, or when no managed JWK has that key ID.
func (m *KeyManager) GetByHeader(ctx context.Context, header fjwt.Mapper) (jwk *JWK, err error) {
	if header == nil {
		return nil, errors.New("jwt header was nil")
	}

	kid, isString := header.Get(JWTHeaderKeyIdentifier).(string)
	if !isString {
		return nil, errors.New("jwt header did not have a kid")
	}

	jwk, managed := m.kids[kid]
	if !managed {
		return nil, fmt.Errorf("jwt header '%s' with value '%s' does not match a managed jwk", JWTHeaderKeyIdentifier, kid)
	}

	return jwk, nil
}
// GetByTokenString parses tokenString without verifying its signature and
// resolves the managed JWK named by the token's header.
func (m *KeyManager) GetByTokenString(ctx context.Context, tokenString string) (jwk *JWK, err error) {
	parser := jwt.NewParser()

	// Only the header is needed to select the key; verification happens later
	// with the selected key.
	token, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{})
	if err != nil {
		return nil, err
	}

	headers := &fjwt.Headers{Extra: token.Header}

	return m.GetByHeader(ctx, headers)
}
// Set returns the public form of every managed JWK as a JSON Web Key Set, with
// the keys ordered by SortedJSONWebKey.
func (m *KeyManager) Set(ctx context.Context) *jose.JSONWebKeySet {
	set := &jose.JSONWebKeySet{
		Keys: make([]jose.JSONWebKey, 0, len(m.kids)),
	}

	for _, managed := range m.kids {
		set.Keys = append(set.Keys, managed.JWK())
	}

	// Map iteration order is random, so sort for a stable result.
	sort.Sort(SortedJSONWebKey(set.Keys))

	return set
}
// Generate signs claims into a JWT using the managed JWK selected by the key
// identifier in header. It returns the token and its signature segment.
func (m *KeyManager) Generate(ctx context.Context, claims fjwt.MapClaims, header fjwt.Mapper) (tokenString string, sig string, err error) {
	jwk, err := m.GetByHeader(ctx, header)
	if err != nil {
		return "", "", fmt.Errorf("error getting jwk from header: %w", err)
	}

	signer := jwk.Strategy()

	return signer.Generate(ctx, claims, header)
}
// Validate verifies tokenString with the managed JWK named in its header and
// returns the token's signature segment.
func (m *KeyManager) Validate(ctx context.Context, tokenString string) (sig string, err error) {
	jwk, err := m.GetByTokenString(ctx, tokenString)
	if err != nil {
		return "", fmt.Errorf("error getting jwk from token string: %w", err)
	}

	signer := jwk.Strategy()

	return signer.Validate(ctx, tokenString)
}
// Hash returns the digest of in, computed with the hash function of the default
// key.
//
// It returns an error when no default key is managed. Without that check,
// GetByKID would return nil and calling Strategy on it would panic.
func (m *KeyManager) Hash(ctx context.Context, in []byte) (sum []byte, err error) {
	jwk := m.GetByKID(ctx, "")
	if jwk == nil {
		return nil, fmt.Errorf("error getting default jwk: key id '%s' does not match a managed jwk", m.kid)
	}

	return jwk.Strategy().Hash(ctx, in)
}
// Decode parses tokenString using the managed JWK named in its header.
func (m *KeyManager) Decode(ctx context.Context, tokenString string) (token *fjwt.Token, err error) {
	jwk, err := m.GetByTokenString(ctx, tokenString)
	if err != nil {
		return nil, fmt.Errorf("error getting jwk from token string: %w", err)
	}

	signer := jwk.Strategy()

	return signer.Decode(ctx, tokenString)
}
// GetSignature returns the signature segment of tokenString without verifying
// the token.
func (m *KeyManager) GetSignature(ctx context.Context, tokenString string) (sig string, err error) {
	sig, err = getTokenSignature(tokenString)

	return sig, err
}
// GetSigningMethodLength returns the signing method length reported by the
// default key's signer.
//
// GetByKID returns nil when no default key is managed, in which case this call
// dereferences a nil JWK.
func (m *KeyManager) GetSigningMethodLength(ctx context.Context) (size int) {
	signer := m.GetByKID(ctx, "").Strategy()

	return signer.GetSigningMethodLength(ctx)
}
// NewJWK converts a configured schema.JWK into a JWK.
//
// The hash is derived from the signing algorithm: SHA-384 for the *384 methods,
// SHA-512 for the *512 methods, and SHA-256 for everything else. The key must
// hold a schema.CryptographicPrivateKey; the type assertion panics otherwise.
func NewJWK(s schema.JWK) (jwk *JWK) {
	jwk = &JWK{
		kid:            s.KeyID,
		use:            s.Use,
		alg:            jwt.GetSigningMethod(s.Algorithm),
		key:            s.Key.(schema.CryptographicPrivateKey),
		chain:          s.CertificateChain,
		thumbprint:     s.CertificateChain.Thumbprint(crypto.SHA256),
		thumbprintsha1: s.CertificateChain.Thumbprint(crypto.SHA1),

		// SHA-256 covers the *256 methods and any unrecognised method.
		hash: crypto.SHA256,
	}

	switch jwk.alg {
	case jwt.SigningMethodRS384, jwt.SigningMethodPS384, jwt.SigningMethodES384:
		jwk.hash = crypto.SHA384
	case jwt.SigningMethodRS512, jwt.SigningMethodPS512, jwt.SigningMethodES512:
		jwk.hash = crypto.SHA512
	}

	return jwk
}
// JWK holds a single signing key together with the metadata needed to publish
// it and sign with it.
type JWK struct {
// kid is the key ID.
kid string
// use is the configured key use value, copied into the published JWK.
use string
// alg is the signing method resolved from the configured algorithm name.
alg jwt.SigningMethod
// hash is the hash function NewJWK derives from alg.
hash crypto.Hash
// key is the private key.
key schema.CryptographicPrivateKey
// chain is the certificate chain configured for the key.
chain schema.X509CertificateChain
// thumbprintsha1 is the SHA-1 thumbprint of chain.
thumbprintsha1 []byte
// thumbprint is the SHA-256 thumbprint of chain.
thumbprint []byte
}
// GetPrivateKey returns the private *jose.JSONWebKey built by PrivateJWK. The
// returned error is always nil and ctx is not used. Strategy passes this method
// to the Signer as its GetPrivateKey callback.
func (j *JWK) GetPrivateKey(ctx context.Context) (any, error) {
return j.PrivateJWK(), nil
}
// KeyID returns the key ID of the JWK.
func (j *JWK) KeyID() string {
return j.kid
}
// PrivateJWK returns a new jose.JSONWebKey that carries the private key along
// with the key ID, algorithm, use, certificates and certificate thumbprints.
func (j *JWK) PrivateJWK() (jwk *jose.JSONWebKey) {
	jwk = new(jose.JSONWebKey)

	jwk.Key = j.key
	jwk.KeyID = j.kid
	jwk.Algorithm = j.alg.Alg()
	jwk.Use = j.use
	jwk.Certificates = j.chain.Certificates()
	jwk.CertificateThumbprintSHA1 = j.thumbprintsha1
	jwk.CertificateThumbprintSHA256 = j.thumbprint

	return jwk
}
// JWK returns the public form of the key, as produced by calling Public on the
// private JSON Web Key.
func (j *JWK) JWK() (jwk jose.JSONWebKey) {
	private := j.PrivateJWK()

	return private.Public()
}
// Strategy returns a new Signer configured with this key's hash, signing method
// and private key callback.
func (j *JWK) Strategy() (strategy fjwt.Signer) {
	signer := new(Signer)

	signer.hash = j.hash
	signer.alg = j.alg
	signer.GetPrivateKey = j.GetPrivateKey

	return signer
}
// Signer is responsible for generating and validating JWT challenges.
type Signer struct {
// hash is the hash function used by Hash and GetSigningMethodLength.
hash crypto.Hash
// alg is the signing method used when generating tokens.
alg jwt.SigningMethod
// GetPrivateKey supplies the key material used for signing; the public key used
// for validation and decoding is derived from it.
GetPrivateKey fjwt.GetPrivateKeyFunc
}
// GetPublicKey derives the public key from whatever GetPrivateKey returns.
//
// Supported values are a *jose.JSONWebKey, a jose.OpaqueSigner and a
// schema.CryptographicPrivateKey; any other type yields an error.
func (j *Signer) GetPublicKey(ctx context.Context) (key crypto.PublicKey, err error) {
	private, err := j.GetPrivateKey(ctx)
	if err != nil {
		return nil, err
	}

	// Case order matters: the concrete JWK type is matched before the interfaces.
	switch typed := private.(type) {
	case *jose.JSONWebKey:
		public := typed.Public()

		return public.Key, nil
	case jose.OpaqueSigner:
		public := typed.Public()

		return public.Key, nil
	case schema.CryptographicPrivateKey:
		return typed.Public(), nil
	}

	return nil, errors.New("invalid private key type")
}
// Generate signs claims and header into a JWT using the Signer's signing method
// and the key returned by GetPrivateKey. It returns the token, the token's
// signature segment, or an error when the key cannot be obtained, has an
// unsupported type, or signing fails.
func (j *Signer) Generate(ctx context.Context, claims fjwt.MapClaims, header fjwt.Mapper) (tokenString string, sig string, err error) {
var key any
if key, err = j.GetPrivateKey(ctx); err != nil {
return "", "", err
}
// The case order matters: concrete key types are matched before the
// jose.OpaqueSigner interface.
switch t := key.(type) {
case *jose.JSONWebKey:
// Sign with the key carried by the JWK rather than the JWK wrapper itself.
return generateToken(claims, header, j.alg, t.Key)
case jose.JSONWebKey:
return generateToken(claims, header, j.alg, t.Key)
case *rsa.PrivateKey, *ecdsa.PrivateKey:
return generateToken(claims, header, j.alg, t)
case jose.OpaqueSigner:
// NOTE(review): Public() presumably yields the public half of the key, so
// matching private key types here looks unreachable in practice, and the
// opaque signer itself (not a raw key) is what gets passed on for signing.
// Confirm whether opaque signers are meant to be supported.
switch tt := t.Public().Key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey:
return generateToken(claims, header, j.alg, t)
default:
return "", "", fmt.Errorf("unsupported private / public key pairs: %T, %T", t, tt)
}
default:
return "", "", fmt.Errorf("unsupported private key type: %T", t)
}
}
// Validate checks tokenString against the Signer's public key and returns the
// token's signature segment, or an error if the token is not valid.
func (j *Signer) Validate(ctx context.Context, tokenString string) (sig string, err error) {
	public, err := j.GetPublicKey(ctx)
	if err != nil {
		return "", err
	}

	return validateToken(tokenString, public)
}
// Decode parses tokenString using the Signer's public key.
func (j *Signer) Decode(ctx context.Context, tokenString string) (token *fjwt.Token, err error) {
	public, err := j.GetPublicKey(ctx)
	if err != nil {
		return nil, err
	}

	return decodeToken(tokenString, public)
}
// GetSignature returns the signature segment of tokenString without verifying
// the token.
func (j *Signer) GetSignature(ctx context.Context, tokenString string) (sig string, err error) {
	sig, err = getTokenSignature(tokenString)

	return sig, err
}
// Hash returns the digest of in computed with the Signer's hash function. On a
// write failure it returns an empty slice and the error with a stack attached.
func (j *Signer) Hash(ctx context.Context, in []byte) (sum []byte, err error) {
	hasher := j.hash.New()

	_, err = hasher.Write(in)
	if err != nil {
		return []byte{}, errorsx.WithStack(err)
	}

	sum = hasher.Sum(nil)

	return sum, nil
}
// GetSigningMethodLength returns the size in bytes of the digest produced by
// the Signer's hash function.
func (j *Signer) GetSigningMethodLength(ctx context.Context) (size int) {
	size = j.hash.Size()

	return size
}
// generateToken builds a JWT from claims, merges header into the token's
// headers without overriding the ones the token already has, signs it with key
// using signingMethod, and returns the token with its signature segment.
func generateToken(claims fjwt.MapClaims, header fjwt.Mapper, signingMethod jwt.SigningMethod, key any) (rawToken string, sig string, err error) {
	if claims == nil || header == nil {
		return "", "", errors.New("either claims or header is nil")
	}

	token := jwt.NewWithClaims(signingMethod, claims)
	token.Header = assign(token.Header, header.ToMap())

	rawToken, err = token.SignedString(key)
	if err != nil {
		return "", "", err
	}

	sig, err = getTokenSignature(rawToken)
	if err != nil {
		return "", "", err
	}

	return rawToken, sig, nil
}
// decodeToken parses tokenString into map claims, supplying key as the
// verification key for every token.
func decodeToken(tokenString string, key any) (token *fjwt.Token, err error) {
	keyFunc := func(*fjwt.Token) (any, error) {
		return key, nil
	}

	return fjwt.ParseWithClaims(tokenString, fjwt.MapClaims{}, keyFunc)
}
// validateToken decodes tokenString with key and, when that succeeds, returns
// the token's signature segment.
func validateToken(tokenString string, key any) (sig string, err error) {
	_, err = decodeToken(tokenString, key)
	if err != nil {
		return "", err
	}

	return getTokenSignature(tokenString)
}
// getTokenSignature returns the third dot-separated segment of tokenString. It
// returns an error unless the token consists of exactly three segments; the
// segments themselves may be empty.
func getTokenSignature(tokenString string) (sig string, err error) {
	// Exactly two separators means exactly three segments.
	if strings.Count(tokenString, ".") != 2 {
		return "", errors.New("header, body and signature must all be set")
	}

	last := strings.LastIndexByte(tokenString, '.')

	return tokenString[last+1:], nil
}
// assign copies into a every entry of b whose key is not already present in a,
// and returns a. Entries that already exist in a are never overwritten.
//
// A nil a is allocated on demand so that copying into it does not panic; when
// both a is nil and b is empty, nil is returned unchanged.
func assign(a, b map[string]any) map[string]any {
	if a == nil && len(b) > 0 {
		a = make(map[string]any, len(b))
	}

	for k, w := range b {
		if _, ok := a[k]; !ok {
			a[k] = w
		}
	}

	return a
}