204 lines
5.9 KiB
Go
204 lines
5.9 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/ory/fosite/token/jwt"
|
|
"gopkg.in/square/go-jose.v2"
|
|
|
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
|
"github.com/authelia/authelia/v4/internal/utils"
|
|
)
|
|
|
|
// NewKeyManagerWithConfiguration when provided a schema.OpenIDConnectConfiguration creates a new KeyManager and adds an
|
|
// initial key to the manager.
|
|
func NewKeyManagerWithConfiguration(configuration *schema.OpenIDConnectConfiguration) (manager *KeyManager, err error) {
|
|
manager = NewKeyManager()
|
|
|
|
_, _, err = manager.AddActivePrivateKeyData(configuration.IssuerPrivateKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return manager, nil
|
|
}
|
|
|
|
// NewKeyManager creates a new empty KeyManager.
|
|
func NewKeyManager() (manager *KeyManager) {
|
|
manager = new(KeyManager)
|
|
manager.keys = map[string]*rsa.PrivateKey{}
|
|
manager.keySet = new(jose.JSONWebKeySet)
|
|
|
|
return manager
|
|
}
|
|
|
|
// Strategy returns the RS256JWTStrategy.
|
|
func (m KeyManager) Strategy() (strategy *RS256JWTStrategy) {
|
|
return m.strategy
|
|
}
|
|
|
|
// GetKeySet returns the joseJSONWebKeySet containing the rsa.PublicKey types.
|
|
func (m KeyManager) GetKeySet() (keySet *jose.JSONWebKeySet) {
|
|
return m.keySet
|
|
}
|
|
|
|
// GetActiveWebKey obtains the currently active jose.JSONWebKey.
|
|
func (m KeyManager) GetActiveWebKey() (webKey *jose.JSONWebKey, err error) {
|
|
webKeys := m.keySet.Key(m.activeKeyID)
|
|
if len(webKeys) == 1 {
|
|
return &webKeys[0], nil
|
|
}
|
|
|
|
if len(webKeys) == 0 {
|
|
return nil, errors.New("could not find a key with the active key id")
|
|
}
|
|
|
|
return &webKeys[0], errors.New("multiple keys with the same key id")
|
|
}
|
|
|
|
// GetActiveKeyID returns the key id of the currently active key.
|
|
func (m KeyManager) GetActiveKeyID() (keyID string) {
|
|
return m.activeKeyID
|
|
}
|
|
|
|
// GetActiveKey returns the rsa.PublicKey of the currently active key.
|
|
func (m KeyManager) GetActiveKey() (key *rsa.PublicKey, err error) {
|
|
if key, ok := m.keys[m.activeKeyID]; ok {
|
|
return &key.PublicKey, nil
|
|
}
|
|
|
|
return nil, errors.New("failed to retrieve active public key")
|
|
}
|
|
|
|
// GetActivePrivateKey returns the rsa.PrivateKey of the currently active key.
|
|
func (m KeyManager) GetActivePrivateKey() (key *rsa.PrivateKey, err error) {
|
|
if key, ok := m.keys[m.activeKeyID]; ok {
|
|
return key, nil
|
|
}
|
|
|
|
return nil, errors.New("failed to retrieve active private key")
|
|
}
|
|
|
|
// AddActivePrivateKeyData adds a rsa.PublicKey given the key in the PEM string format, then sets it to the active key.
|
|
func (m *KeyManager) AddActivePrivateKeyData(data string) (key *rsa.PrivateKey, webKey *jose.JSONWebKey, err error) {
|
|
ikey, err := utils.ParseX509FromPEM([]byte(data))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
var ok bool
|
|
|
|
if key, ok = ikey.(*rsa.PrivateKey); !ok {
|
|
return nil, nil, errors.New("key must be an RSA private key")
|
|
}
|
|
|
|
webKey, err = m.AddActivePrivateKey(key)
|
|
|
|
return key, webKey, err
|
|
}
|
|
|
|
// AddActivePrivateKey adds a rsa.PublicKey, then sets it to the active key.
|
|
func (m *KeyManager) AddActivePrivateKey(key *rsa.PrivateKey) (webKey *jose.JSONWebKey, err error) {
|
|
wk := jose.JSONWebKey{
|
|
Key: &key.PublicKey,
|
|
Algorithm: "RS256",
|
|
Use: "sig",
|
|
}
|
|
|
|
keyID, err := wk.Thumbprint(crypto.SHA1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
strKeyID := strings.ToLower(fmt.Sprintf("%x", keyID))
|
|
if len(strKeyID) >= 7 {
|
|
// Shorten the key if it's greater than 7 to a length of exactly 7.
|
|
strKeyID = strKeyID[0:6]
|
|
}
|
|
|
|
if _, ok := m.keys[strKeyID]; ok {
|
|
return nil, fmt.Errorf("key id %s already exists", strKeyID)
|
|
}
|
|
|
|
// TODO: Add Mutex here when implementing key rotation.
|
|
wk.KeyID = strKeyID
|
|
m.keySet.Keys = append(m.keySet.Keys, wk)
|
|
m.keys[strKeyID] = key
|
|
m.activeKeyID = strKeyID
|
|
|
|
m.strategy, err = NewRS256JWTStrategy(wk.KeyID, key)
|
|
if err != nil {
|
|
return &wk, err
|
|
}
|
|
|
|
return &wk, nil
|
|
}
|
|
|
|
// NewRS256JWTStrategy returns a new RS256JWTStrategy.
|
|
func NewRS256JWTStrategy(id string, key *rsa.PrivateKey) (strategy *RS256JWTStrategy, err error) {
|
|
strategy = new(RS256JWTStrategy)
|
|
strategy.JWTStrategy = new(jwt.RS256JWTStrategy)
|
|
|
|
strategy.SetKey(id, key)
|
|
|
|
return strategy, nil
|
|
}
|
|
|
|
// RS256JWTStrategy is a decorator struct for the fosite RS256JWTStrategy.
|
|
type RS256JWTStrategy struct {
|
|
JWTStrategy *jwt.RS256JWTStrategy
|
|
|
|
keyID string
|
|
}
|
|
|
|
// KeyID returns the key id.
|
|
func (s RS256JWTStrategy) KeyID() (id string) {
|
|
return s.keyID
|
|
}
|
|
|
|
// SetKey sets the provided key id and key as the active key (this is what triggers fosite to use it).
|
|
func (s *RS256JWTStrategy) SetKey(id string, key *rsa.PrivateKey) {
|
|
s.keyID = id
|
|
s.JWTStrategy.PrivateKey = key
|
|
}
|
|
|
|
// Hash is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) {
|
|
return s.JWTStrategy.Hash(ctx, in)
|
|
}
|
|
|
|
// GetSigningMethodLength is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) GetSigningMethodLength() int {
|
|
return s.JWTStrategy.GetSigningMethodLength()
|
|
}
|
|
|
|
// GetSignature is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) {
|
|
return s.JWTStrategy.GetSignature(ctx, token)
|
|
}
|
|
|
|
// Generate is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) Generate(ctx context.Context, claims jwt.MapClaims, header jwt.Mapper) (string, string, error) {
|
|
return s.JWTStrategy.Generate(ctx, claims, header)
|
|
}
|
|
|
|
// Validate is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) {
|
|
return s.JWTStrategy.Validate(ctx, token)
|
|
}
|
|
|
|
// Decode is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) Decode(ctx context.Context, token string) (*jwt.Token, error) {
|
|
return s.JWTStrategy.Decode(ctx, token)
|
|
}
|
|
|
|
// GetPublicKeyID is a decorator func for the underlying fosite RS256JWTStrategy.
|
|
func (s *RS256JWTStrategy) GetPublicKeyID(_ context.Context) (string, error) {
|
|
return s.keyID, nil
|
|
}
|