401 lines
9.7 KiB
Go
401 lines
9.7 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/golang-jwt/jwt/v4"
|
|
fjwt "github.com/ory/fosite/token/jwt"
|
|
"github.com/ory/x/errorsx"
|
|
|
|
"gopkg.in/square/go-jose.v2"
|
|
|
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
|
)
|
|
|
|
// NewKeyManager news up a KeyManager.
|
|
func NewKeyManager(config *schema.OpenIDConnectConfiguration) (manager *KeyManager) {
|
|
manager = &KeyManager{
|
|
kids: map[string]*JWK{},
|
|
algs: map[string]*JWK{},
|
|
}
|
|
|
|
for _, sjwk := range config.IssuerJWKS {
|
|
jwk := NewJWK(sjwk)
|
|
|
|
manager.kids[sjwk.KeyID] = jwk
|
|
manager.algs[jwk.alg.Alg()] = jwk
|
|
|
|
if jwk.kid == config.Discovery.DefaultKeyID {
|
|
manager.kid = jwk.kid
|
|
}
|
|
}
|
|
|
|
return manager
|
|
}
|
|
|
|
// The KeyManager type handles JWKs and signing operations.
|
|
type KeyManager struct {
|
|
kid string
|
|
kids map[string]*JWK
|
|
algs map[string]*JWK
|
|
}
|
|
|
|
func (m *KeyManager) GetKIDFromAlgStrict(ctx context.Context, alg string) (kid string, err error) {
|
|
if jwks, ok := m.algs[alg]; ok {
|
|
return jwks.kid, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("alg not found")
|
|
}
|
|
|
|
func (m *KeyManager) GetKIDFromAlg(ctx context.Context, alg string) string {
|
|
if jwks, ok := m.algs[alg]; ok {
|
|
return jwks.kid
|
|
}
|
|
|
|
return m.kid
|
|
}
|
|
|
|
func (m *KeyManager) GetByAlg(ctx context.Context, alg string) *JWK {
|
|
if jwk, ok := m.algs[alg]; ok {
|
|
return jwk
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *KeyManager) GetByKID(ctx context.Context, kid string) *JWK {
|
|
if kid == "" {
|
|
return m.kids[m.kid]
|
|
}
|
|
|
|
if jwk, ok := m.kids[kid]; ok {
|
|
return jwk
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *KeyManager) GetByHeader(ctx context.Context, header fjwt.Mapper) (jwk *JWK, err error) {
|
|
var (
|
|
kid string
|
|
ok bool
|
|
)
|
|
|
|
if header == nil {
|
|
return nil, fmt.Errorf("jwt header was nil")
|
|
}
|
|
|
|
if kid, ok = header.Get(JWTHeaderKeyIdentifier).(string); !ok {
|
|
return nil, fmt.Errorf("jwt header did not have a kid")
|
|
}
|
|
|
|
if jwk, ok = m.kids[kid]; !ok {
|
|
return nil, fmt.Errorf("jwt header '%s' with value '%s' does not match a managed jwk", JWTHeaderKeyIdentifier, kid)
|
|
}
|
|
|
|
return jwk, nil
|
|
}
|
|
|
|
func (m *KeyManager) GetByTokenString(ctx context.Context, tokenString string) (jwk *JWK, err error) {
|
|
var (
|
|
token *jwt.Token
|
|
)
|
|
|
|
if token, _, err = jwt.NewParser().ParseUnverified(tokenString, jwt.MapClaims{}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return m.GetByHeader(ctx, &fjwt.Headers{Extra: token.Header})
|
|
}
|
|
|
|
func (m *KeyManager) Set(ctx context.Context) *jose.JSONWebKeySet {
|
|
keys := make([]jose.JSONWebKey, 0, len(m.kids))
|
|
|
|
for _, jwk := range m.kids {
|
|
keys = append(keys, jwk.JWK())
|
|
}
|
|
|
|
sort.Sort(SortedJSONWebKey(keys))
|
|
|
|
return &jose.JSONWebKeySet{
|
|
Keys: keys,
|
|
}
|
|
}
|
|
|
|
func (m *KeyManager) Generate(ctx context.Context, claims fjwt.MapClaims, header fjwt.Mapper) (tokenString string, sig string, err error) {
|
|
var jwk *JWK
|
|
|
|
if jwk, err = m.GetByHeader(ctx, header); err != nil {
|
|
return "", "", fmt.Errorf("error getting jwk from header: %w", err)
|
|
}
|
|
|
|
return jwk.Strategy().Generate(ctx, claims, header)
|
|
}
|
|
|
|
func (m *KeyManager) Validate(ctx context.Context, tokenString string) (sig string, err error) {
|
|
var jwk *JWK
|
|
|
|
if jwk, err = m.GetByTokenString(ctx, tokenString); err != nil {
|
|
return "", fmt.Errorf("error getting jwk from token string: %w", err)
|
|
}
|
|
|
|
return jwk.Strategy().Validate(ctx, tokenString)
|
|
}
|
|
|
|
func (m *KeyManager) Hash(ctx context.Context, in []byte) (sum []byte, err error) {
|
|
return m.GetByKID(ctx, "").Strategy().Hash(ctx, in)
|
|
}
|
|
|
|
func (m *KeyManager) Decode(ctx context.Context, tokenString string) (token *fjwt.Token, err error) {
|
|
var jwk *JWK
|
|
|
|
if jwk, err = m.GetByTokenString(ctx, tokenString); err != nil {
|
|
return nil, fmt.Errorf("error getting jwk from token string: %w", err)
|
|
}
|
|
|
|
return jwk.Strategy().Decode(ctx, tokenString)
|
|
}
|
|
|
|
func (m *KeyManager) GetSignature(ctx context.Context, tokenString string) (sig string, err error) {
|
|
return getTokenSignature(tokenString)
|
|
}
|
|
|
|
func (m *KeyManager) GetSigningMethodLength(ctx context.Context) (size int) {
|
|
return m.GetByKID(ctx, "").Strategy().GetSigningMethodLength(ctx)
|
|
}
|
|
|
|
func NewJWK(s schema.JWK) (jwk *JWK) {
|
|
jwk = &JWK{
|
|
kid: s.KeyID,
|
|
use: s.Use,
|
|
alg: jwt.GetSigningMethod(s.Algorithm),
|
|
key: s.Key.(schema.CryptographicPrivateKey),
|
|
|
|
chain: s.CertificateChain,
|
|
thumbprint: s.CertificateChain.Thumbprint(crypto.SHA256),
|
|
thumbprintsha1: s.CertificateChain.Thumbprint(crypto.SHA1),
|
|
}
|
|
|
|
switch jwk.alg {
|
|
case jwt.SigningMethodRS256, jwt.SigningMethodPS256, jwt.SigningMethodES256:
|
|
jwk.hash = crypto.SHA256
|
|
case jwt.SigningMethodRS384, jwt.SigningMethodPS384, jwt.SigningMethodES384:
|
|
jwk.hash = crypto.SHA384
|
|
case jwt.SigningMethodRS512, jwt.SigningMethodPS512, jwt.SigningMethodES512:
|
|
jwk.hash = crypto.SHA512
|
|
default:
|
|
jwk.hash = crypto.SHA256
|
|
}
|
|
|
|
return jwk
|
|
}
|
|
|
|
type JWK struct {
|
|
kid string
|
|
use string
|
|
alg jwt.SigningMethod
|
|
hash crypto.Hash
|
|
|
|
key schema.CryptographicPrivateKey
|
|
chain schema.X509CertificateChain
|
|
thumbprintsha1 []byte
|
|
thumbprint []byte
|
|
}
|
|
|
|
func (j *JWK) GetPrivateKey(ctx context.Context) (any, error) {
|
|
return j.PrivateJWK(), nil
|
|
}
|
|
|
|
func (j *JWK) KeyID() string {
|
|
return j.kid
|
|
}
|
|
|
|
func (j *JWK) PrivateJWK() (jwk *jose.JSONWebKey) {
|
|
return &jose.JSONWebKey{
|
|
Key: j.key,
|
|
KeyID: j.kid,
|
|
Algorithm: j.alg.Alg(),
|
|
Use: j.use,
|
|
Certificates: j.chain.Certificates(),
|
|
CertificateThumbprintSHA1: j.thumbprintsha1,
|
|
CertificateThumbprintSHA256: j.thumbprint,
|
|
}
|
|
}
|
|
|
|
func (j *JWK) JWK() (jwk jose.JSONWebKey) {
|
|
return j.PrivateJWK().Public()
|
|
}
|
|
|
|
func (j *JWK) Strategy() (strategy fjwt.Signer) {
|
|
return &Signer{
|
|
hash: j.hash,
|
|
alg: j.alg,
|
|
GetPrivateKey: j.GetPrivateKey,
|
|
}
|
|
}
|
|
|
|
// Signer is responsible for generating and validating JWT challenges.
|
|
type Signer struct {
|
|
hash crypto.Hash
|
|
alg jwt.SigningMethod
|
|
|
|
GetPrivateKey fjwt.GetPrivateKeyFunc
|
|
}
|
|
|
|
func (j *Signer) GetPublicKey(ctx context.Context) (key crypto.PublicKey, err error) {
|
|
var k any
|
|
|
|
if k, err = j.GetPrivateKey(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch t := k.(type) {
|
|
case *jose.JSONWebKey:
|
|
return t.Public().Key, nil
|
|
case jose.OpaqueSigner:
|
|
return t.Public().Key, nil
|
|
case schema.CryptographicPrivateKey:
|
|
return t.Public(), nil
|
|
default:
|
|
return nil, errors.New("invalid private key type")
|
|
}
|
|
}
|
|
|
|
// Generate generates a new authorize code or returns an error. set secret.
|
|
func (j *Signer) Generate(ctx context.Context, claims fjwt.MapClaims, header fjwt.Mapper) (tokenString string, sig string, err error) {
|
|
var key any
|
|
|
|
if key, err = j.GetPrivateKey(ctx); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
switch t := key.(type) {
|
|
case *jose.JSONWebKey:
|
|
return generateToken(claims, header, j.alg, t.Key)
|
|
case jose.JSONWebKey:
|
|
return generateToken(claims, header, j.alg, t.Key)
|
|
case *rsa.PrivateKey, *ecdsa.PrivateKey:
|
|
return generateToken(claims, header, j.alg, t)
|
|
case jose.OpaqueSigner:
|
|
switch tt := t.Public().Key.(type) {
|
|
case *rsa.PrivateKey, *ecdsa.PrivateKey:
|
|
return generateToken(claims, header, j.alg, t)
|
|
default:
|
|
return "", "", fmt.Errorf("unsupported private / public key pairs: %T, %T", t, tt)
|
|
}
|
|
default:
|
|
return "", "", fmt.Errorf("unsupported private key type: %T", t)
|
|
}
|
|
}
|
|
|
|
// Validate validates a token and returns its signature or an error if the token is not valid.
|
|
func (j *Signer) Validate(ctx context.Context, tokenString string) (sig string, err error) {
|
|
var (
|
|
key crypto.PublicKey
|
|
)
|
|
|
|
if key, err = j.GetPublicKey(ctx); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return validateToken(tokenString, key)
|
|
}
|
|
|
|
// Decode will decode a JWT token.
|
|
func (j *Signer) Decode(ctx context.Context, tokenString string) (token *fjwt.Token, err error) {
|
|
var (
|
|
key crypto.PublicKey
|
|
)
|
|
|
|
if key, err = j.GetPublicKey(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return decodeToken(tokenString, key)
|
|
}
|
|
|
|
// GetSignature will return the signature of a token.
|
|
func (j *Signer) GetSignature(ctx context.Context, tokenString string) (sig string, err error) {
|
|
return getTokenSignature(tokenString)
|
|
}
|
|
|
|
// Hash will return a given hash based on the byte input or an error upon fail.
|
|
func (j *Signer) Hash(ctx context.Context, in []byte) (sum []byte, err error) {
|
|
hash := j.hash.New()
|
|
|
|
if _, err = hash.Write(in); err != nil {
|
|
return []byte{}, errorsx.WithStack(err)
|
|
}
|
|
|
|
return hash.Sum([]byte{}), nil
|
|
}
|
|
|
|
// GetSigningMethodLength will return the length of the signing method.
|
|
func (j *Signer) GetSigningMethodLength(ctx context.Context) (size int) {
|
|
return j.hash.Size()
|
|
}
|
|
|
|
func generateToken(claims fjwt.MapClaims, header fjwt.Mapper, signingMethod jwt.SigningMethod, key any) (rawToken string, sig string, err error) {
|
|
if header == nil || claims == nil {
|
|
return "", "", errors.New("either claims or header is nil")
|
|
}
|
|
|
|
token := jwt.NewWithClaims(signingMethod, claims)
|
|
|
|
token.Header = assign(token.Header, header.ToMap())
|
|
|
|
if rawToken, err = token.SignedString(key); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
if sig, err = getTokenSignature(rawToken); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
return rawToken, sig, nil
|
|
}
|
|
|
|
func decodeToken(tokenString string, key any) (token *fjwt.Token, err error) {
|
|
return fjwt.ParseWithClaims(tokenString, fjwt.MapClaims{}, func(*fjwt.Token) (any, error) {
|
|
return key, nil
|
|
})
|
|
}
|
|
|
|
func validateToken(tokenString string, key any) (sig string, err error) {
|
|
if _, err = decodeToken(tokenString, key); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return getTokenSignature(tokenString)
|
|
}
|
|
|
|
// getTokenSignature returns the third dot separated segment of the raw JWT. An error is returned unless the token
// consists of exactly three segments; the segments themselves may be empty.
func getTokenSignature(tokenString string) (sig string, err error) {
	// Exactly two separators means exactly three segments.
	if strings.Count(tokenString, ".") != 2 {
		return "", errors.New("header, body and signature must all be set")
	}

	return tokenString[strings.LastIndex(tokenString, ".")+1:], nil
}
|
|
|
|
// assign copies every entry of b whose key is not already present in a into a, and returns a. Entries already in a
// are never overwritten. The map a is modified in place; when a is nil a new map is allocated and returned instead,
// because writing to a nil map panics.
func assign(a, b map[string]any) map[string]any {
	if a == nil {
		a = make(map[string]any, len(b))
	}

	for k, v := range b {
		if _, exists := a[k]; !exists {
			a[k] = v
		}
	}

	return a
}
|