authelia/internal/oidc/keys.go

198 lines
5.8 KiB
Go

package oidc
import (
"context"
"crypto"
"crypto/rsa"
"errors"
"fmt"
"strings"
"github.com/ory/fosite/token/jwt"
"gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewKeyManagerWithConfiguration when provided a schema.OpenIDConnectConfiguration creates a new KeyManager and adds an
// initial key to the manager.
func NewKeyManagerWithConfiguration(configuration *schema.OpenIDConnectConfiguration) (manager *KeyManager, err error) {
manager = NewKeyManager()
_, _, err = manager.AddActivePrivateKeyData(configuration.IssuerPrivateKey)
if err != nil {
return nil, err
}
return manager, nil
}
// NewKeyManager creates a new empty KeyManager.
func NewKeyManager() (manager *KeyManager) {
manager = new(KeyManager)
manager.keys = map[string]*rsa.PrivateKey{}
manager.keySet = new(jose.JSONWebKeySet)
return manager
}
// Strategy returns the RS256JWTStrategy.
func (m KeyManager) Strategy() (strategy *RS256JWTStrategy) {
return m.strategy
}
// GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types.
func (m KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
return m.keySet
}
// GetActiveWebKey obtains the currently active jose.JSONWebKey.
func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
webKeys := m.keySet.Key(m.activeKeyID)
if len(webKeys) == 1 {
return &webKeys[0], nil
}
if len(webKeys) == 0 {
return nil, errors.New("could not find a key with the active key id")
}
return &webKeys[0], errors.New("multiple keys with the same key id")
}
// GetActiveKeyID returns the key id of the currently active key.
func (m KeyManager) GetActiveKeyID() (keyID string) {
return m.activeKeyID
}
// GetActiveKey returns the rsa.PublicKey of the currently active key.
func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
if key, ok := m.keys[m.activeKeyID]; ok {
return &key.PublicKey, nil
}
return nil, errors.New("failed to retrieve active public key")
}
// GetActivePrivateKey returns the rsa.PrivateKey of the currently active key.
func (m KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
if key, ok := m.keys[m.activeKeyID]; ok {
return key, nil
}
return nil, errors.New("failed to retrieve active private key")
}
// AddActivePrivateKeyData adds a rsa.PublicKey given the key in the PEM string format, then sets it to the active key.
func (m *KeyManager) AddActivePrivateKeyData(data string) (key *rsa.PrivateKey, webKey *jose.JSONWebKey, err error) {
key, err = utils.ParseRsaPrivateKeyFromPemStr(data)
if err != nil {
return nil, nil, err
}
webKey, err = m.AddActivePrivateKey(key)
return key, webKey, err
}
// AddActivePrivateKey adds a rsa.PublicKey, then sets it to the active key.
func (m *KeyManager) AddActivePrivateKey(key *rsa.PrivateKey) (webKey *jose.JSONWebKey, err error) {
wk := jose.JSONWebKey{
Key: &key.PublicKey,
Algorithm: "RS256",
Use: "sig",
}
keyID, err := wk.Thumbprint(crypto.SHA1)
if err != nil {
return nil, err
}
strKeyID := strings.ToLower(fmt.Sprintf("%x", keyID))
if len(strKeyID) >= 7 {
// Shorten the key if it's greater than 7 to a length of exactly 7.
strKeyID = strKeyID[0:6]
}
if _, ok := m.keys[strKeyID]; ok {
return nil, fmt.Errorf("key id %s already exists", strKeyID)
}
// TODO: Add Mutex here when implementing key rotation.
wk.KeyID = strKeyID
m.keySet.Keys = append(m.keySet.Keys, wk)
m.keys[strKeyID] = key
m.activeKeyID = strKeyID
m.strategy, err = NewRS256JWTStrategy(wk.KeyID, key)
if err != nil {
return &wk, err
}
return &wk, nil
}
// NewRS256JWTStrategy returns a new RS256JWTStrategy.
func NewRS256JWTStrategy(id string, key *rsa.PrivateKey) (strategy *RS256JWTStrategy, err error) {
strategy = new(RS256JWTStrategy)
strategy.JWTStrategy = new(jwt.RS256JWTStrategy)
strategy.SetKey(id, key)
return strategy, nil
}
// RS256JWTStrategy is a decorator struct for the fosite RS256JWTStrategy.
type RS256JWTStrategy struct {
JWTStrategy *jwt.RS256JWTStrategy
keyID string
}
// KeyID returns the key id.
func (s RS256JWTStrategy) KeyID() (id string) {
return s.keyID
}
// SetKey sets the provided key id and key as the active key (this is what triggers fosite to use it).
func (s *RS256JWTStrategy) SetKey(id string, key *rsa.PrivateKey) {
s.keyID = id
s.JWTStrategy.PrivateKey = key
}
// Hash is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) {
return s.JWTStrategy.Hash(ctx, in)
}
// GetSigningMethodLength is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) GetSigningMethodLength() int {
return s.JWTStrategy.GetSigningMethodLength()
}
// GetSignature is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) {
return s.JWTStrategy.GetSignature(ctx, token)
}
// Generate is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) Generate(ctx context.Context, claims jwt.MapClaims, header jwt.Mapper) (string, string, error) {
return s.JWTStrategy.Generate(ctx, claims, header)
}
// Validate is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) {
return s.JWTStrategy.Validate(ctx, token)
}
// Decode is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) Decode(ctx context.Context, token string) (*jwt.Token, error) {
return s.JWTStrategy.Decode(ctx, token)
}
// GetPublicKeyID is a decorator func for the underlying fosite RS256JWTStrategy.
func (s *RS256JWTStrategy) GetPublicKeyID(_ context.Context) (string, error) {
return s.keyID, nil
}