feat: implement mutual tls in the web server (#3065)

Mutual TLS helps prevent untrusted clients communicating with services like Authelia. This can be utilized to reduce the attack surface.

Fixes #3041
pull/3113/head
Clément Michaud 2022-04-05 01:57:47 +02:00 committed by GitHub
parent a2eb0316c8
commit 3ca438e3d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 891 additions and 172 deletions

View File

@ -19,6 +19,9 @@ type HostEntry struct {
} }
var hostEntries = []HostEntry{ var hostEntries = []HostEntry{
// For unit tests.
{Domain: "local.example.com", IP: "127.0.0.1"},
// For authelia backend. // For authelia backend.
{Domain: "authelia.example.com", IP: "192.168.240.50"}, {Domain: "authelia.example.com", IP: "192.168.240.50"},

View File

@ -71,6 +71,9 @@ server:
## The path to the DER base64/PEM format public certificate. ## The path to the DER base64/PEM format public certificate.
certificate: "" certificate: ""
## The list of certificates for client authentication.
client_certificates: []
## Server headers configuration/customization. ## Server headers configuration/customization.
headers: headers:

View File

@ -24,6 +24,7 @@ server:
tls: tls:
key: "" key: ""
certificate: "" certificate: ""
client_certificates: []
headers: headers:
csp_template: "" csp_template: ""
``` ```
@ -213,6 +214,19 @@ required: situational
The path to the public certificate for TLS connections. Must be in DER base64/PEM format. The path to the public certificate for TLS connections. Must be in DER base64/PEM format.
#### client_certificates
<div markdown="1">
type: list(string)
{: .label .label-config .label-purple }
default: []
{: .label .label-config .label-blue }
required: no
{: .label .label-config .label-yellow }
</div>
The list of file paths to certificates used for authenticating clients. Those certificates can be root
or intermediate certificates. If no item is provided mutual TLS is disabled.
### headers ### headers

View File

@ -189,6 +189,19 @@ Authelia protects your users against open redirect attacks by always checking if
to a subdomain of the domain protected by Authelia. This prevents phishing campaigns tricking users into visiting to a subdomain of the domain protected by Authelia. This prevents phishing campaigns tricking users into visiting
infected websites leveraging legit links. infected websites leveraging legit links.
## Mutual TLS
For the best security protection, configuration with TLS is highly recommended. TLS is used to secure the connection between
the proxies and Authelia instances meaning that an attacker on the network cannot perform a man-in-the-middle attack on those
connections. However, an attacker on the network can still impersonate proxies but this can be prevented by configuring mutual
TLS.
Mutual TLS brings mutual authentication between Authelia and the proxies. Any other party attempting to contact Authelia
would not even be able to create a TCP connection. This measure is recommended in all cases except if you already configured
some kind of ACLs specifically allowing the communication between proxies and Authelia instances like in a service mesh or
some kind of network overlay.
To configure mutual TLS, please refer to [this document](../configuration/server.md#client_certificates)
## Additional security ## Additional security
### Reset Password ### Reset Password

View File

@ -1,22 +1,15 @@
package commands package commands
import ( import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt" "fmt"
"log" "log"
"math/big"
"net"
"os" "os"
"path/filepath" "path/filepath"
"time" "time"
"github.com/authelia/authelia/v4/internal/utils"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -113,14 +106,13 @@ func cmdCertificatesGenerateRun(cmd *cobra.Command, _ []string) {
} }
func cmdCertificatesGenerateRunExtended(hosts []string, ecdsaCurve, validFrom, certificateTargetDirectory string, ed25519Key, isCA bool, rsaBits int, validFor time.Duration) { func cmdCertificatesGenerateRunExtended(hosts []string, ecdsaCurve, validFrom, certificateTargetDirectory string, ed25519Key, isCA bool, rsaBits int, validFor time.Duration) {
priv, err := getPrivateKey(ecdsaCurve, ed25519Key, rsaBits) certPath := filepath.Join(certificateTargetDirectory, "cert.pem")
keyPath := filepath.Join(certificateTargetDirectory, "key.pem")
if err != nil { var (
fmt.Printf("Failed to generate private key: %v\n", err) notBefore time.Time
os.Exit(1) err error
} )
var notBefore time.Time
switch len(validFrom) { switch len(validFrom) {
case 0: case 0:
@ -128,122 +120,47 @@ func cmdCertificatesGenerateRunExtended(hosts []string, ecdsaCurve, validFrom, c
default: default:
notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom) notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom)
if err != nil { if err != nil {
fmt.Printf("Failed to parse start date: %v\n", err) log.Fatalf("Failed to parse start date: %v", err)
os.Exit(1)
} }
} }
notAfter := notBefore.Add(validFor) var privateKeyBuilder utils.PrivateKeyBuilder
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
fmt.Printf("Failed to generate serial number: %v\n", err)
os.Exit(1)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
if isCA {
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
}
certPath := filepath.Join(certificateTargetDirectory, "cert.pem")
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv)
if err != nil {
fmt.Printf("Failed to create certificate: %v\n", err)
os.Exit(1)
}
writePEM(derBytes, "CERTIFICATE", certPath)
fmt.Printf("Certificate Public Key written to %s\n", certPath)
keyPath := filepath.Join(certificateTargetDirectory, "key.pem")
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
fmt.Printf("Failed to marshal private key: %v\n", err)
os.Exit(1)
}
writePEM(privBytes, "PRIVATE KEY", keyPath)
fmt.Printf("Certificate Private Key written to %s\n", keyPath)
}
func getPrivateKey(ecdsaCurve string, ed25519Key bool, rsaBits int) (priv interface{}, err error) {
switch ecdsaCurve { switch ecdsaCurve {
case "": case "":
if ed25519Key { if ed25519Key {
_, priv, err = ed25519.GenerateKey(rand.Reader) privateKeyBuilder = utils.Ed25519KeyBuilder{}
} else { } else {
priv, err = rsa.GenerateKey(rand.Reader, rsaBits) privateKeyBuilder = utils.RSAKeyBuilder{}.WithKeySize(rsaBits)
} }
case "P224": case "P224":
priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P224())
case "P256": case "P256":
priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
case "P384": case "P384":
priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P384())
case "P521": case "P521":
priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P521())
default: default:
err = fmt.Errorf("unrecognized elliptic curve: %q", ecdsaCurve) log.Fatalf("Failed to generate private key: unrecognized elliptic curve: \"%s\"", ecdsaCurve)
} }
return priv, err certBytes, keyBytes, err := utils.GenerateCertificate(privateKeyBuilder, hosts, notBefore, validFor, isCA)
}
func writePEM(bytes []byte, blockType, path string) {
keyOut, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
fmt.Printf("Failed to open %s for writing: %v\n", path, err) log.Fatal(err)
os.Exit(1)
} }
if err := pem.Encode(keyOut, &pem.Block{Type: blockType, Bytes: bytes}); err != nil { err = os.WriteFile(certPath, certBytes, 0600)
fmt.Printf("Failed to write data to %s: %v\n", path, err) if err != nil {
os.Exit(1) log.Fatalf("failed to write %s for writing: %v", certPath, err)
} }
if err := keyOut.Close(); err != nil { fmt.Printf("Certificate written to %s\n", certPath)
fmt.Printf("Error closing %s: %v\n", path, err)
os.Exit(1) err = os.WriteFile(keyPath, keyBytes, 0600)
} if err != nil {
} log.Fatalf("failed to write %s for writing: %v", certPath, err)
func publicKey(priv interface{}) interface{} {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
} }
fmt.Printf("Private Key written to %s\n", keyPath)
} }

View File

@ -77,7 +77,9 @@ func cmdRootRun(_ *cobra.Command, _ []string) {
doStartupChecks(config, &providers) doStartupChecks(config, &providers)
server.Start(*config, providers) s, listener := server.CreateServer(*config, providers)
logger.Fatal(s.Serve(listener))
} }
func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) { func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) {

View File

@ -71,6 +71,9 @@ server:
## The path to the DER base64/PEM format public certificate. ## The path to the DER base64/PEM format public certificate.
certificate: "" certificate: ""
## The list of certificates for client authentication.
client_certificates: []
## Server headers configuration/customization. ## Server headers configuration/customization.
headers: headers:

View File

@ -20,6 +20,7 @@ type ServerConfiguration struct {
type ServerTLSConfiguration struct { type ServerTLSConfiguration struct {
Certificate string `koanf:"certificate"` Certificate string `koanf:"certificate"`
Key string `koanf:"key"` Key string `koanf:"key"`
ClientCertificates []string `koanf:"client_certificates"`
} }
// ServerHeadersConfiguration represents the customization of the http server headers. // ServerHeadersConfiguration represents the customization of the http server headers.

View File

@ -44,8 +44,6 @@ const (
testLDAPURL = "ldap://ldap" testLDAPURL = "ldap://ldap"
testLDAPUser = "user" testLDAPUser = "user"
testModeDisabled = "disable" testModeDisabled = "disable"
testTLSCert = "/tmp/cert.pem"
testTLSKey = "/tmp/key.pem"
testEncryptionKey = "a_not_so_secure_encryption_key" testEncryptionKey = "a_not_so_secure_encryption_key"
) )
@ -224,6 +222,10 @@ const (
const ( const (
errFmtServerTLSCert = "server: tls: option 'key' must also be accompanied by option 'certificate'" errFmtServerTLSCert = "server: tls: option 'key' must also be accompanied by option 'certificate'"
errFmtServerTLSKey = "server: tls: option 'certificate' must also be accompanied by option 'key'" errFmtServerTLSKey = "server: tls: option 'certificate' must also be accompanied by option 'key'"
errFmtServerTLSCertFileDoesNotExist = "server: tls: file path %s provided in 'certificate' does not exist"
errFmtServerTLSKeyFileDoesNotExist = "server: tls: file path %s provided in 'key' does not exist"
errFmtServerTLSClientAuthCertFileDoesNotExist = "server: tls: client_certificates: certificates: file path %s does not exist"
errFmtServerTLSClientAuthNoAuth = "server: tls: client authentication cannot be configured if no server certificate and key are provided"
errFmtServerPathNoForwardSlashes = "server: option 'path' must not contain any forward slashes" errFmtServerPathNoForwardSlashes = "server: option 'path' must not contain any forward slashes"
errFmtServerPathAlphaNum = "server: option 'path' must only contain alpha numeric characters" errFmtServerPathAlphaNum = "server: option 'path' must only contain alpha numeric characters"

View File

@ -9,6 +9,44 @@ import (
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// validateFileExists checks whether a file exist.
func validateFileExists(path string, validator *schema.StructValidator, errTemplate string) {
exist, err := utils.FileExists(path)
if err != nil {
validator.Push(fmt.Errorf("tls: unable to check if file %s exists: %s", path, err))
}
if !exist {
validator.Push(fmt.Errorf(errTemplate, path))
}
}
// ValidateServerTLS checks a server TLS configuration is correct.
func ValidateServerTLS(config *schema.Configuration, validator *schema.StructValidator) {
if config.Server.TLS.Key != "" && config.Server.TLS.Certificate == "" {
validator.Push(fmt.Errorf(errFmtServerTLSCert))
} else if config.Server.TLS.Key == "" && config.Server.TLS.Certificate != "" {
validator.Push(fmt.Errorf(errFmtServerTLSKey))
}
if config.Server.TLS.Key != "" {
validateFileExists(config.Server.TLS.Key, validator, errFmtServerTLSKeyFileDoesNotExist)
}
if config.Server.TLS.Certificate != "" {
validateFileExists(config.Server.TLS.Certificate, validator, errFmtServerTLSCertFileDoesNotExist)
}
if config.Server.TLS.Key == "" && config.Server.TLS.Certificate == "" &&
len(config.Server.TLS.ClientCertificates) > 0 {
validator.Push(fmt.Errorf(errFmtServerTLSClientAuthNoAuth))
}
for _, clientCertPath := range config.Server.TLS.ClientCertificates {
validateFileExists(clientCertPath, validator, errFmtServerTLSClientAuthCertFileDoesNotExist)
}
}
// ValidateServer checks a server configuration is correct. // ValidateServer checks a server configuration is correct.
func ValidateServer(config *schema.Configuration, validator *schema.StructValidator) { func ValidateServer(config *schema.Configuration, validator *schema.StructValidator) {
if config.Server.Host == "" { if config.Server.Host == "" {
@ -19,11 +57,7 @@ func ValidateServer(config *schema.Configuration, validator *schema.StructValida
config.Server.Port = schema.DefaultServerConfiguration.Port config.Server.Port = schema.DefaultServerConfiguration.Port
} }
if config.Server.TLS.Key != "" && config.Server.TLS.Certificate == "" { ValidateServerTLS(config, validator)
validator.Push(fmt.Errorf(errFmtServerTLSCert))
} else if config.Server.TLS.Key == "" && config.Server.TLS.Certificate != "" {
validator.Push(fmt.Errorf(errFmtServerTLSKey))
}
switch { switch {
case strings.Contains(config.Server.Path, "/"): case strings.Contains(config.Server.Path, "/"):

View File

@ -1,6 +1,7 @@
package validator package validator
import ( import (
"os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -9,6 +10,8 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
) )
const unexistingFilePath = "/tmp/unexisting_file"
func TestShouldSetDefaultServerValues(t *testing.T) { func TestShouldSetDefaultServerValues(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := &schema.Configuration{} config := &schema.Configuration{}
@ -119,33 +122,129 @@ func TestShouldValidateAndUpdateHost(t *testing.T) {
func TestShouldRaiseErrorWhenTLSCertWithoutKeyIsProvided(t *testing.T) { func TestShouldRaiseErrorWhenTLSCertWithoutKeyIsProvided(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultConfig() config := newDefaultConfig()
config.Server.TLS.Certificate = testTLSCert
file, err := os.CreateTemp("", "cert")
require.NoError(t, err)
defer os.Remove(file.Name())
config.Server.TLS.Certificate = file.Name()
ValidateServer(&config, validator) ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 1) require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "server: tls: option 'certificate' must also be accompanied by option 'key'") assert.EqualError(t, validator.Errors()[0], "server: tls: option 'certificate' must also be accompanied by option 'key'")
} }
func TestShouldRaiseErrorWhenTLSCertDoesNotExist(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultConfig()
file, err := os.CreateTemp("", "key")
require.NoError(t, err)
defer os.Remove(file.Name())
config.Server.TLS.Certificate = unexistingFilePath
config.Server.TLS.Key = file.Name()
ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "server: tls: file path /tmp/unexisting_file provided in 'certificate' does not exist")
}
func TestShouldRaiseErrorWhenTLSKeyWithoutCertIsProvided(t *testing.T) { func TestShouldRaiseErrorWhenTLSKeyWithoutCertIsProvided(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultConfig() config := newDefaultConfig()
config.Server.TLS.Key = testTLSKey
file, err := os.CreateTemp("", "key")
require.NoError(t, err)
defer os.Remove(file.Name())
config.Server.TLS.Key = file.Name()
ValidateServer(&config, validator) ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 1) require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "server: tls: option 'key' must also be accompanied by option 'certificate'") assert.EqualError(t, validator.Errors()[0], "server: tls: option 'key' must also be accompanied by option 'certificate'")
} }
func TestShouldRaiseErrorWhenTLSKeyDoesNotExist(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultConfig()
file, err := os.CreateTemp("", "key")
require.NoError(t, err)
defer os.Remove(file.Name())
config.Server.TLS.Key = unexistingFilePath
config.Server.TLS.Certificate = file.Name()
ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "server: tls: file path /tmp/unexisting_file provided in 'key' does not exist")
}
func TestShouldNotRaiseErrorWhenBothTLSCertificateAndKeyAreProvided(t *testing.T) { func TestShouldNotRaiseErrorWhenBothTLSCertificateAndKeyAreProvided(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultConfig() config := newDefaultConfig()
config.Server.TLS.Certificate = testTLSCert
config.Server.TLS.Key = testTLSKey certFile, err := os.CreateTemp("", "cert")
require.NoError(t, err)
defer os.Remove(certFile.Name())
keyFile, err := os.CreateTemp("", "key")
require.NoError(t, err)
defer os.Remove(keyFile.Name())
config.Server.TLS.Certificate = certFile.Name()
config.Server.TLS.Key = keyFile.Name()
ValidateServer(&config, validator) ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 0) require.Len(t, validator.Errors(), 0)
} }
func TestShouldRaiseErrorWhenTLSClientCertificateDoesNotExist(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultConfig()
certFile, err := os.CreateTemp("", "cert")
require.NoError(t, err)
defer os.Remove(certFile.Name())
keyFile, err := os.CreateTemp("", "key")
require.NoError(t, err)
defer os.Remove(keyFile.Name())
config.Server.TLS.Certificate = certFile.Name()
config.Server.TLS.Key = keyFile.Name()
config.Server.TLS.ClientCertificates = []string{"/tmp/unexisting"}
ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "server: tls: client_certificates: certificates: file path /tmp/unexisting does not exist")
}
func TestShouldRaiseErrorWhenTLSClientAuthIsDefinedButNotServerCertificate(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultConfig()
certFile, err := os.CreateTemp("", "cert")
require.NoError(t, err)
defer os.Remove(certFile.Name())
config.Server.TLS.ClientCertificates = []string{certFile.Name()}
ValidateServer(&config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "server: tls: client authentication cannot be configured if no server certificate and key are provided")
}
func TestShouldNotUpdateConfig(t *testing.T) { func TestShouldNotUpdateConfig(t *testing.T) {
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := newDefaultConfig() config := newDefaultConfig()

View File

@ -40,6 +40,9 @@ var (
} }
) )
const schemeHTTP = "http"
const schemeHTTPS = "https"
const ( const (
dev = "dev" dev = "dev"
f = "false" f = "false"

View File

@ -1,11 +1,15 @@
package server package server
import ( import (
"crypto/tls"
"crypto/x509"
"net" "net"
"os" "os"
"strconv" "strconv"
"time" "time"
"github.com/authelia/authelia/v4/internal/logging"
duoapi "github.com/duosecurity/duo_api_golang" duoapi "github.com/duosecurity/duo_api_golang"
"github.com/fasthttp/router" "github.com/fasthttp/router"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
@ -15,7 +19,6 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/duo" "github.com/authelia/authelia/v4/internal/duo"
"github.com/authelia/authelia/v4/internal/handlers" "github.com/authelia/authelia/v4/internal/handlers"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
) )
@ -178,10 +181,8 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
return handler return handler
} }
// Start Authelia's internal webserver with the given configuration and providers. // CreateServer Create Authelia's internal webserver with the given configuration and providers.
func Start(configuration schema.Configuration, providers middlewares.Providers) { func CreateServer(configuration schema.Configuration, providers middlewares.Providers) (*fasthttp.Server, net.Listener) {
logger := logging.Logger()
handler := registerRoutes(configuration, providers) handler := registerRoutes(configuration, providers)
server := &fasthttp.Server{ server := &fasthttp.Server{
@ -191,36 +192,66 @@ func Start(configuration schema.Configuration, providers middlewares.Providers)
ReadBufferSize: configuration.Server.ReadBufferSize, ReadBufferSize: configuration.Server.ReadBufferSize,
WriteBufferSize: configuration.Server.WriteBufferSize, WriteBufferSize: configuration.Server.WriteBufferSize,
} }
logger := logging.Logger()
address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port)) address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port))
listener, err := net.Listen("tcp", address) var (
listener net.Listener
err error
connectionType string
connectionScheme string
)
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" {
connectionType, connectionScheme = "TLS", schemeHTTPS
err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)
if err != nil {
logger.Fatalf("unable to load certificate: %v", err)
}
if len(configuration.Server.TLS.ClientCertificates) > 0 {
caCertPool := x509.NewCertPool()
for _, path := range configuration.Server.TLS.ClientCertificates {
cert, err := os.ReadFile(path)
if err != nil {
logger.Fatalf("Cannot read client TLS certificate %s: %s", path, err)
}
caCertPool.AppendCertsFromPEM(cert)
}
// ClientCAs should never be nil, otherwise the system cert pool is used for client authentication
// but we don't want everybody on the Internet to be able to authenticate.
server.TLSConfig.ClientCAs = caCertPool
server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone())
if err != nil { if err != nil {
logger.Fatalf("Error initializing listener: %s", err) logger.Fatalf("Error initializing listener: %s", err)
} }
} else {
connectionType, connectionScheme = "non-TLS", schemeHTTP
listener, err = net.Listen("tcp", address)
if err != nil {
logger.Fatalf("Error initializing listener: %s", err)
}
}
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" { if err = writeHealthCheckEnv(configuration.Server.DisableHealthcheck, connectionScheme, configuration.Server.Host,
if err = writeHealthCheckEnv(configuration.Server.DisableHealthcheck, "https", configuration.Server.Host, configuration.Server.Path, configuration.Server.Port); err != nil { configuration.Server.Path, configuration.Server.Port); err != nil {
logger.Fatalf("Could not configure healthcheck: %v", err) logger.Fatalf("Could not configure healthcheck: %v", err)
} }
actualAddress := listener.Addr().String()
if configuration.Server.Path == "" { if configuration.Server.Path == "" {
logger.Infof("Listening for TLS connections on '%s' path '/'", address) logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress)
} else { } else {
logger.Infof("Listening for TLS connections on '%s' paths '/' and '%s'", address, configuration.Server.Path) logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path)
} }
logger.Fatal(server.ServeTLS(listener, configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)) return server, listener
} else {
if err = writeHealthCheckEnv(configuration.Server.DisableHealthcheck, "http", configuration.Server.Host, configuration.Server.Path, configuration.Server.Port); err != nil {
logger.Fatalf("Could not configure healthcheck: %v", err)
}
if configuration.Server.Path == "" {
logger.Infof("Listening for non-TLS connections on '%s' path '/'", address)
} else {
logger.Infof("Listening for non-TLS connections on '%s' paths '/' and '%s'", address, configuration.Server.Path)
}
logger.Fatal(server.Serve(listener))
}
} }

View File

@ -0,0 +1,379 @@
package server
import (
"crypto/elliptic"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/utils"
)
// TemporaryCertificate contains the FD of 2 temporary files containing the PEM format of the certificate and private key.
type TemporaryCertificate struct {
CertFile *os.File
KeyFile *os.File
Certificate *x509.Certificate
CertificatePEM []byte
KeyPEM []byte
}
func (tc TemporaryCertificate) TLSCertificate() (tls.Certificate, error) {
return tls.LoadX509KeyPair(tc.CertFile.Name(), tc.KeyFile.Name())
}
func (tc *TemporaryCertificate) Close() {
if tc.CertFile != nil {
tc.CertFile.Close()
}
if tc.KeyFile != nil {
tc.KeyFile.Close()
}
}
type CertificateContext struct {
Certificates []TemporaryCertificate
privateKeyBuilder utils.PrivateKeyBuilder
}
// NewCertificateContext instantiate a new certificate context used to easily generate certificates within tests.
func NewCertificateContext(privateKeyBuilder utils.PrivateKeyBuilder) (*CertificateContext, error) {
certificateContext := new(CertificateContext)
certificateContext.privateKeyBuilder = privateKeyBuilder
cert, err := certificateContext.GenerateCertificate()
if err != nil {
return nil, err
}
certificateContext.Certificates = []TemporaryCertificate{*cert}
return certificateContext, nil
}
// GenerateCertificate generate a new certificate in the context.
func (cc *CertificateContext) GenerateCertificate() (*TemporaryCertificate, error) {
certBytes, keyBytes, err := utils.GenerateCertificate(cc.privateKeyBuilder,
[]string{"authelia.com", "example.org", "local.example.com"},
time.Now(), 3*time.Hour, false)
if err != nil {
return nil, fmt.Errorf("unable to generate certificate: %v", err)
}
tmpCertificate := new(TemporaryCertificate)
certFile, err := os.CreateTemp("", "cert")
if err != nil {
return nil, fmt.Errorf("unable to create temp file for certificate: %v", err)
}
tmpCertificate.CertFile = certFile
tmpCertificate.CertificatePEM = certBytes
block, _ := pem.Decode(certBytes)
c, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("unable to parse certificate: %v", err)
}
tmpCertificate.Certificate = c
err = os.WriteFile(tmpCertificate.CertFile.Name(), certBytes, 0600)
if err != nil {
tmpCertificate.Close()
return nil, fmt.Errorf("unable to write certificates in file: %v", err)
}
keyFile, err := os.CreateTemp("", "key")
if err != nil {
tmpCertificate.Close()
return nil, fmt.Errorf("unable to create temp file for private key: %v", err)
}
tmpCertificate.KeyFile = keyFile
tmpCertificate.KeyPEM = keyBytes
err = os.WriteFile(tmpCertificate.KeyFile.Name(), keyBytes, 0600)
if err != nil {
tmpCertificate.Close()
return nil, fmt.Errorf("unable to write private key in file: %v", err)
}
cc.Certificates = append(cc.Certificates, *tmpCertificate)
return tmpCertificate, nil
}
func (cc *CertificateContext) Close() {
for _, tc := range cc.Certificates {
tc.Close()
}
}
type TLSServerContext struct {
server *fasthttp.Server
port int
}
func NewTLSServerContext(configuration schema.Configuration) (*TLSServerContext, error) {
serverContext := new(TLSServerContext)
s, listener := CreateServer(configuration, middlewares.Providers{})
serverContext.server = s
go func() {
err := s.Serve(listener)
if err != nil {
logging.Logger().Fatal(err)
}
}()
addrSplit := strings.Split(listener.Addr().String(), ":")
if len(addrSplit) > 1 {
port, err := strconv.ParseInt(addrSplit[len(addrSplit)-1], 10, 32)
if err != nil {
return nil, fmt.Errorf("unable to parse port from address: %v", err)
}
serverContext.port = int(port)
}
return serverContext, nil
}
func (sc *TLSServerContext) Port() int {
return sc.port
}
func (sc *TLSServerContext) Close() error {
return sc.server.Shutdown()
}
func TestShouldRaiseErrorWhenClientDoesNotSkipVerify(t *testing.T) {
privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
certificateContext, err := NewCertificateContext(privateKeyBuilder)
require.NoError(t, err)
defer certificateContext.Close()
tlsServerContext, err := NewTLSServerContext(schema.Configuration{
Server: schema.ServerConfiguration{
TLS: schema.ServerTLSConfiguration{
Certificate: certificateContext.Certificates[0].CertFile.Name(),
Key: certificateContext.Certificates[0].KeyFile.Name(),
},
},
})
require.NoError(t, err)
defer tlsServerContext.Close()
fmt.Println(tlsServerContext.Port())
req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d", tlsServerContext.Port()), nil)
require.NoError(t, err)
_, err = http.DefaultClient.Do(req)
require.Error(t, err)
require.Contains(t, err.Error(), "x509: certificate signed by unknown authority")
}
func TestShouldServeOverTLSWhenClientDoesSkipVerify(t *testing.T) {
privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
certificateContext, err := NewCertificateContext(privateKeyBuilder)
require.NoError(t, err)
defer certificateContext.Close()
tlsServerContext, err := NewTLSServerContext(schema.Configuration{
Server: schema.ServerConfiguration{
TLS: schema.ServerTLSConfiguration{
Certificate: certificateContext.Certificates[0].CertFile.Name(),
Key: certificateContext.Certificates[0].KeyFile.Name(),
},
},
})
require.NoError(t, err)
defer tlsServerContext.Close()
req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
require.NoError(t, err)
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in tests. Not used in production.
}
client := &http.Client{Transport: tr}
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
_, err = io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "404 Not Found", res.Status)
}
func TestShouldServeOverTLSWhenClientHasProperRootCA(t *testing.T) {
privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
certificateContext, err := NewCertificateContext(privateKeyBuilder)
require.NoError(t, err)
defer certificateContext.Close()
tlsServerContext, err := NewTLSServerContext(schema.Configuration{
Server: schema.ServerConfiguration{
TLS: schema.ServerTLSConfiguration{
Certificate: certificateContext.Certificates[0].CertFile.Name(),
Key: certificateContext.Certificates[0].KeyFile.Name(),
},
},
})
require.NoError(t, err)
defer tlsServerContext.Close()
req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
require.NoError(t, err)
block, _ := pem.Decode(certificateContext.Certificates[0].CertificatePEM)
c, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err)
// Create a root CA for the client to properly validate server cert.
rootCAs := x509.NewCertPool()
rootCAs.AddCert(c)
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
MinVersion: tls.VersionTLS13,
},
}
client := &http.Client{Transport: tr}
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
_, err = io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "404 Not Found", res.Status)
}
func TestShouldRaiseWhenMutualTLSIsConfiguredAndClientIsNotAuthenticated(t *testing.T) {
privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
certificateContext, err := NewCertificateContext(privateKeyBuilder)
require.NoError(t, err)
defer certificateContext.Close()
clientCert, err := certificateContext.GenerateCertificate()
require.NoError(t, err)
tlsServerContext, err := NewTLSServerContext(schema.Configuration{
Server: schema.ServerConfiguration{
TLS: schema.ServerTLSConfiguration{
Certificate: certificateContext.Certificates[0].CertFile.Name(),
Key: certificateContext.Certificates[0].KeyFile.Name(),
ClientCertificates: []string{clientCert.CertFile.Name()},
},
},
})
require.NoError(t, err)
defer tlsServerContext.Close()
req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
require.NoError(t, err)
// Create a root CA for the client to properly validate server cert.
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificateContext.Certificates[0].Certificate)
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
MinVersion: tls.VersionTLS13,
},
}
client := &http.Client{Transport: tr}
_, err = client.Do(req)
require.Error(t, err)
assert.Contains(t, err.Error(), "remote error: tls: bad certificate")
}
func TestShouldServeProperlyWhenMutualTLSIsConfiguredAndClientIsAuthenticated(t *testing.T) {
privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
certificateContext, err := NewCertificateContext(privateKeyBuilder)
require.NoError(t, err)
defer certificateContext.Close()
clientCert, err := certificateContext.GenerateCertificate()
require.NoError(t, err)
tlsServerContext, err := NewTLSServerContext(schema.Configuration{
Server: schema.ServerConfiguration{
TLS: schema.ServerTLSConfiguration{
Certificate: certificateContext.Certificates[0].CertFile.Name(),
Key: certificateContext.Certificates[0].KeyFile.Name(),
ClientCertificates: []string{clientCert.CertFile.Name()},
},
},
})
require.NoError(t, err)
defer tlsServerContext.Close()
req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
require.NoError(t, err)
// Create a root CA for the client to properly validate server cert.
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificateContext.Certificates[0].Certificate)
cCert, err := certificateContext.Certificates[1].TLSCertificate()
require.NoError(t, err)
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
Certificates: []tls.Certificate{cCert},
MinVersion: tls.VersionTLS13,
},
}
client := &http.Client{Transport: tr}
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
_, err = io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "404 Not Found", res.Status)
}

View File

@ -46,7 +46,7 @@ func waitUntilAutheliaBackendIsReady(dockerEnvironment *DockerEnvironment) error
90*time.Second, 90*time.Second,
dockerEnvironment, dockerEnvironment,
"authelia-backend", "authelia-backend",
[]string{"Listening for"}) []string{"Initializing server for"})
} }
func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error { func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error {

View File

@ -45,6 +45,10 @@ func (s *CLISuite) SetupTest() {
} }
func (s *CLISuite) TestShouldPrintBuildInformation() { func (s *CLISuite) TestShouldPrintBuildInformation() {
if os.Getenv("CI") == "false" {
s.T().Skip("Skipping testing in dev environment")
}
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "build-info"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "build-info"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Last Tag: ") s.Assert().Contains(output, "Last Tag: ")
@ -92,22 +96,22 @@ func (s *CLISuite) TestShouldHashPasswordSHA512() {
func (s *CLISuite) TestShouldGenerateCertificateRSA() { func (s *CLISuite) TestShouldGenerateCertificateRSA() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldGenerateCertificateRSAWithIPAddress() { func (s *CLISuite) TestShouldGenerateCertificateRSAWithIPAddress() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=127.0.0.1", "--dir=/tmp/"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=127.0.0.1", "--dir=/tmp/"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldGenerateCertificateRSAWithStartDate() { func (s *CLISuite) TestShouldGenerateCertificateRSAWithStartDate() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--start-date='Jan 1 15:04:05 2011'"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--start-date='Jan 1 15:04:05 2011'"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldFailGenerateCertificateRSAWithStartDate() { func (s *CLISuite) TestShouldFailGenerateCertificateRSAWithStartDate() {
@ -119,15 +123,15 @@ func (s *CLISuite) TestShouldFailGenerateCertificateRSAWithStartDate() {
func (s *CLISuite) TestShouldGenerateCertificateCA() { func (s *CLISuite) TestShouldGenerateCertificateCA() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ca"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ca"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldGenerateCertificateEd25519() { func (s *CLISuite) TestShouldGenerateCertificateEd25519() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ed25519"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ed25519"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldFailGenerateCertificateECDSA() { func (s *CLISuite) TestShouldFailGenerateCertificateECDSA() {
@ -139,29 +143,29 @@ func (s *CLISuite) TestShouldFailGenerateCertificateECDSA() {
func (s *CLISuite) TestShouldGenerateCertificateECDSAP224() { func (s *CLISuite) TestShouldGenerateCertificateECDSAP224() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P224"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P224"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldGenerateCertificateECDSAP256() { func (s *CLISuite) TestShouldGenerateCertificateECDSAP256() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P256"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P256"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldGenerateCertificateECDSAP384() { func (s *CLISuite) TestShouldGenerateCertificateECDSAP384() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P384"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P384"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestShouldGenerateCertificateECDSAP521() { func (s *CLISuite) TestShouldGenerateCertificateECDSAP521() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P521"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P521"})
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") s.Assert().Contains(output, "Certificate written to /tmp/cert.pem")
s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") s.Assert().Contains(output, "Private Key written to /tmp/key.pem")
} }
func (s *CLISuite) TestStorageShouldShowErrWithoutConfig() { func (s *CLISuite) TestStorageShouldShowErrWithoutConfig() {

View File

@ -1,17 +1,38 @@
package utils package utils
import ( import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt" "fmt"
"math/big"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
) )
// PEMBlockType represent an enum of the existing PEM block types.
type PEMBlockType int
const (
// Certificate block type.
Certificate PEMBlockType = iota
// PrivateKey block type.
PrivateKey
)
// NewTLSConfig generates a tls.Config from a schema.TLSConfig and a x509.CertPool. // NewTLSConfig generates a tls.Config from a schema.TLSConfig and a x509.CertPool.
func NewTLSConfig(config *schema.TLSConfig, defaultMinVersion uint16, certPool *x509.CertPool) (tlsConfig *tls.Config) { func NewTLSConfig(config *schema.TLSConfig, defaultMinVersion uint16, certPool *x509.CertPool) (tlsConfig *tls.Config) {
minVersion, err := TLSStringToTLSConfigVersion(config.MinimumVersion) minVersion, err := TLSStringToTLSConfigVersion(config.MinimumVersion)
@ -83,3 +104,150 @@ func TLSStringToTLSConfigVersion(input string) (version uint16, err error) {
return 0, ErrTLSVersionNotSupported return 0, ErrTLSVersionNotSupported
} }
// GenerateCertificate generate a certificate given a private key. RSA, Ed25519 and ECDSA are officially supported.
func GenerateCertificate(privateKeyBuilder PrivateKeyBuilder, hosts []string, validFrom time.Time, validFor time.Duration, isCA bool) ([]byte, []byte, error) {
privateKey, err := privateKeyBuilder.Build()
if err != nil {
return nil, nil, fmt.Errorf("unable to build private key: %w", err)
}
notBefore := validFrom
notAfter := validFrom.Add(validFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate serial number: %v", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Acme Co"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
if isCA {
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
}
certDERBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to create certificate: %v", err)
}
certPEMBytes, err := ConvertDERToPEM(certDERBytes, Certificate)
if err != nil {
return nil, nil, fmt.Errorf("faile to convert certificate in DER format into PEM: %v", err)
}
keyDERBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal private key: %v", err)
}
keyPEMBytes, err := ConvertDERToPEM(keyDERBytes, PrivateKey)
if err != nil {
return nil, nil, fmt.Errorf("faile to convert certificate in DER format into PEM: %v", err)
}
return certPEMBytes, keyPEMBytes, nil
}
// ConvertDERToPEM convert certificate in DER format into PEM format.
func ConvertDERToPEM(der []byte, blockType PEMBlockType) ([]byte, error) {
var buf bytes.Buffer
var blockTypeStr string
switch blockType {
case Certificate:
blockTypeStr = "CERTIFICATE"
case PrivateKey:
blockTypeStr = "PRIVATE KEY"
default:
return nil, fmt.Errorf("unknown PEM block type %d", blockType)
}
if err := pem.Encode(&buf, &pem.Block{Type: blockTypeStr, Bytes: der}); err != nil {
return nil, fmt.Errorf("failed to encode DER data into PEM: %v", err)
}
return buf.Bytes(), nil
}
func publicKey(privateKey interface{}) interface{} {
switch k := privateKey.(type) {
case *rsa.PrivateKey:
return &k.PublicKey
case *ecdsa.PrivateKey:
return &k.PublicKey
case ed25519.PrivateKey:
return k.Public().(ed25519.PublicKey)
default:
return nil
}
}
// PrivateKeyBuilder interface for a private key builder.
type PrivateKeyBuilder interface {
Build() (interface{}, error)
}
// RSAKeyBuilder builder of RSA private key.
type RSAKeyBuilder struct {
keySizeInBits int
}
// WithKeySize configure the key size to use with RSA.
func (rkb RSAKeyBuilder) WithKeySize(bits int) RSAKeyBuilder {
rkb.keySizeInBits = bits
return rkb
}
// Build a RSA private key.
func (rkb RSAKeyBuilder) Build() (interface{}, error) {
return rsa.GenerateKey(rand.Reader, rkb.keySizeInBits)
}
// Ed25519KeyBuilder builder of Ed25519 private key.
type Ed25519KeyBuilder struct{}
// Build an Ed25519 private key.
func (ekb Ed25519KeyBuilder) Build() (interface{}, error) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
return priv, err
}
// ECDSAKeyBuilder builder of ECDSA private key.
type ECDSAKeyBuilder struct {
curve elliptic.Curve
}
// WithCurve configure the curve to use for the ECDSA private key.
func (ekb ECDSAKeyBuilder) WithCurve(curve elliptic.Curve) ECDSAKeyBuilder {
ekb.curve = curve
return ekb
}
// Build an ECDSA private key.
func (ekb ECDSAKeyBuilder) Build() (interface{}, error) {
return ecdsa.GenerateKey(ekb.curve, rand.Reader)
}

View File

@ -1,9 +1,11 @@
package utils package utils
import ( import (
"crypto/elliptic"
"crypto/tls" "crypto/tls"
"runtime" "runtime"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -124,3 +126,44 @@ func TestShouldReadCertsFromDirectoryButNotKeys(t *testing.T) {
assert.EqualError(t, errors[0], "could not import certificate key.pem") assert.EqualError(t, errors[0], "could not import certificate key.pem")
} }
func TestShouldGenerateCertificateAndPersistIt(t *testing.T) {
testCases := []struct {
Name string
PrivateKeyBuilder PrivateKeyBuilder
}{
{
Name: "P224",
PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P224()),
},
{
Name: "P256",
PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P256()),
},
{
Name: "P384",
PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P384()),
},
{
Name: "P521",
PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P521()),
},
{
Name: "Ed25519",
PrivateKeyBuilder: Ed25519KeyBuilder{},
},
{
Name: "RSA",
PrivateKeyBuilder: RSAKeyBuilder{keySizeInBits: 2048},
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
certBytes, keyBytes, err := GenerateCertificate(tc.PrivateKeyBuilder, []string{"authelia.com", "example.org"}, time.Now(), 3*time.Hour, false)
require.NoError(t, err)
assert.True(t, len(certBytes) > 0)
assert.True(t, len(keyBytes) > 0)
})
}
}