diff --git a/cmd/authelia-scripts/cmd_bootstrap.go b/cmd/authelia-scripts/cmd_bootstrap.go index 28797e194..a7dbcb4db 100644 --- a/cmd/authelia-scripts/cmd_bootstrap.go +++ b/cmd/authelia-scripts/cmd_bootstrap.go @@ -19,6 +19,9 @@ type HostEntry struct { } var hostEntries = []HostEntry{ + // For unit tests. + {Domain: "local.example.com", IP: "127.0.0.1"}, + // For authelia backend. {Domain: "authelia.example.com", IP: "192.168.240.50"}, diff --git a/config.template.yml b/config.template.yml index 042aeecdd..db21d7af5 100644 --- a/config.template.yml +++ b/config.template.yml @@ -71,6 +71,9 @@ server: ## The path to the DER base64/PEM format public certificate. certificate: "" + ## The list of certificates for client authentication. + client_certificates: [] + ## Server headers configuration/customization. headers: diff --git a/docs/configuration/server.md b/docs/configuration/server.md index 5f3d28e27..3a705a425 100644 --- a/docs/configuration/server.md +++ b/docs/configuration/server.md @@ -24,6 +24,7 @@ server: tls: key: "" certificate: "" + client_certificates: [] headers: csp_template: "" ``` @@ -213,6 +214,19 @@ required: situational The path to the public certificate for TLS connections. Must be in DER base64/PEM format. +#### client_certificates +
+type: list(string) +{: .label .label-config .label-purple } +default: [] +{: .label .label-config .label-blue } +required: no +{: .label .label-config .label-yellow } +
+ +The list of file paths to certificates used for authenticating clients. Those certificates can be root +or intermediate certificates. If no item is provided mutual TLS is disabled. + ### headers diff --git a/docs/security/measures.md b/docs/security/measures.md index 556357f96..12bd1c566 100644 --- a/docs/security/measures.md +++ b/docs/security/measures.md @@ -189,6 +189,19 @@ Authelia protects your users against open redirect attacks by always checking if to a subdomain of the domain protected by Authelia. This prevents phishing campaigns tricking users into visiting infected websites leveraging legit links. +## Mutual TLS + +For the best security protection, configuration with TLS is highly recommended. TLS is used to secure the connection between +the proxies and Authelia instances meaning that an attacker on the network cannot perform a man-in-the-middle attack on those +connections. However, an attacker on the network can still impersonate proxies but this can be prevented by configuring mutual +TLS. +Mutual TLS brings mutual authentication between Authelia and the proxies. Any other party attempting to contact Authelia +would not even be able to create a TCP connection. This measure is recommended in all cases except if you already configured +some kind of ACLs specifically allowing the communication between proxies and Authelia instances like in a service mesh or +some kind of network overlay. + +To configure mutual TLS, please refer to [this document](../configuration/server.md#client_certificates) + ## Additional security ### Reset Password diff --git a/internal/commands/certificates.go b/internal/commands/certificates.go index 6257aae61..b6f0bfc5b 100644 --- a/internal/commands/certificates.go +++ b/internal/commands/certificates.go @@ -1,22 +1,15 @@ package commands import ( - "crypto/ecdsa" - "crypto/ed25519" "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "fmt" "log" - "math/big" - "net" "os" "path/filepath" "time" + "github.com/authelia/authelia/v4/internal/utils" + "github.com/spf13/cobra" ) @@ -113,14 +106,13 @@ func cmdCertificatesGenerateRun(cmd *cobra.Command, _ []string) { } func cmdCertificatesGenerateRunExtended(hosts []string, ecdsaCurve, validFrom, certificateTargetDirectory string, ed25519Key, isCA bool, rsaBits int, validFor time.Duration) { - priv, err := getPrivateKey(ecdsaCurve, ed25519Key, rsaBits) + certPath := filepath.Join(certificateTargetDirectory, "cert.pem") + keyPath := filepath.Join(certificateTargetDirectory, "key.pem") - if err != nil { - fmt.Printf("Failed to generate private key: %v\n", err) - os.Exit(1) - } - - var notBefore time.Time + var ( + notBefore time.Time + err error + ) switch len(validFrom) { case 0: @@ -128,122 +120,47 @@ func cmdCertificatesGenerateRunExtended(hosts []string, ecdsaCurve, validFrom, c default: notBefore, err = time.Parse("Jan 2 15:04:05 2006", validFrom) if err != nil { - fmt.Printf("Failed to parse start date: %v\n", err) - os.Exit(1) + log.Fatalf("Failed to parse start date: %v", err) } } - notAfter := notBefore.Add(validFor) + var privateKeyBuilder utils.PrivateKeyBuilder - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - fmt.Printf("Failed to generate serial number: %v\n", err) - os.Exit(1) - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"Acme Co"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - for _, h := range hosts { - if ip := net.ParseIP(h); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, h) - } - } - - if isCA { - template.IsCA = true - template.KeyUsage |= x509.KeyUsageCertSign - } - - certPath := filepath.Join(certificateTargetDirectory, "cert.pem") - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(priv), priv) - if err != nil { - fmt.Printf("Failed to create certificate: %v\n", err) - os.Exit(1) - } - - writePEM(derBytes, "CERTIFICATE", certPath) - - fmt.Printf("Certificate Public Key written to %s\n", certPath) - - keyPath := filepath.Join(certificateTargetDirectory, "key.pem") - - privBytes, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - fmt.Printf("Failed to marshal private key: %v\n", err) - os.Exit(1) - } - - writePEM(privBytes, "PRIVATE KEY", keyPath) - - fmt.Printf("Certificate Private Key written to %s\n", keyPath) -} - -func getPrivateKey(ecdsaCurve string, ed25519Key bool, rsaBits int) (priv interface{}, err error) { switch ecdsaCurve { case "": if ed25519Key { - _, priv, err = ed25519.GenerateKey(rand.Reader) + privateKeyBuilder = utils.Ed25519KeyBuilder{} } else { - priv, err = rsa.GenerateKey(rand.Reader, rsaBits) + privateKeyBuilder = utils.RSAKeyBuilder{}.WithKeySize(rsaBits) } case "P224": - priv, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P224()) case "P256": - priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256()) case "P384": - priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P384()) case "P521": - priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + privateKeyBuilder = utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P521()) default: - err = fmt.Errorf("unrecognized elliptic curve: %q", ecdsaCurve) + log.Fatalf("Failed to generate private key: unrecognized elliptic curve: \"%s\"", ecdsaCurve) } - return priv, err -} - -func writePEM(bytes []byte, blockType, path string) { - keyOut, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + certBytes, keyBytes, err := utils.GenerateCertificate(privateKeyBuilder, hosts, notBefore, validFor, isCA) if err != nil { - fmt.Printf("Failed to open %s for writing: %v\n", path, err) - os.Exit(1) + log.Fatal(err) } - if err := pem.Encode(keyOut, &pem.Block{Type: blockType, Bytes: bytes}); err != nil { - fmt.Printf("Failed to write data to %s: %v\n", path, err) - os.Exit(1) + err = os.WriteFile(certPath, certBytes, 0600) + if err != nil { + log.Fatalf("failed to write %s for writing: %v", certPath, err) } - if err := keyOut.Close(); err != nil { - fmt.Printf("Error closing %s: %v\n", path, err) - os.Exit(1) - } -} - -func publicKey(priv interface{}) interface{} { - switch k := priv.(type) { - case *rsa.PrivateKey: - return &k.PublicKey - case *ecdsa.PrivateKey: - return &k.PublicKey - case ed25519.PrivateKey: - return k.Public().(ed25519.PublicKey) - default: - return nil + fmt.Printf("Certificate written to %s\n", certPath) + + err = os.WriteFile(keyPath, keyBytes, 0600) + if err != nil { + log.Fatalf("failed to write %s for writing: %v", certPath, err) } + + fmt.Printf("Private Key written to %s\n", keyPath) } diff --git a/internal/commands/root.go b/internal/commands/root.go index 38ce814df..d64e99c2c 100644 --- a/internal/commands/root.go +++ b/internal/commands/root.go @@ -77,7 +77,9 @@ func cmdRootRun(_ *cobra.Command, _ []string) { doStartupChecks(config, &providers) - server.Start(*config, providers) + s, listener := server.CreateServer(*config, providers) + + logger.Fatal(s.Serve(listener)) } func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) { diff --git a/internal/configuration/config.template.yml b/internal/configuration/config.template.yml index 042aeecdd..db21d7af5 100644 --- a/internal/configuration/config.template.yml +++ b/internal/configuration/config.template.yml @@ -71,6 +71,9 @@ server: ## The path to the DER base64/PEM format public certificate. certificate: "" + ## The list of certificates for client authentication. + client_certificates: [] + ## Server headers configuration/customization. headers: diff --git a/internal/configuration/schema/server.go b/internal/configuration/schema/server.go index 528df6e54..77146a454 100644 --- a/internal/configuration/schema/server.go +++ b/internal/configuration/schema/server.go @@ -18,8 +18,9 @@ type ServerConfiguration struct { // ServerTLSConfiguration represents the configuration of the http servers TLS options. type ServerTLSConfiguration struct { - Certificate string `koanf:"certificate"` - Key string `koanf:"key"` + Certificate string `koanf:"certificate"` + Key string `koanf:"key"` + ClientCertificates []string `koanf:"client_certificates"` } // ServerHeadersConfiguration represents the customization of the http server headers. diff --git a/internal/configuration/validator/const.go b/internal/configuration/validator/const.go index 75a79486d..21a470609 100644 --- a/internal/configuration/validator/const.go +++ b/internal/configuration/validator/const.go @@ -44,8 +44,6 @@ const ( testLDAPURL = "ldap://ldap" testLDAPUser = "user" testModeDisabled = "disable" - testTLSCert = "/tmp/cert.pem" - testTLSKey = "/tmp/key.pem" testEncryptionKey = "a_not_so_secure_encryption_key" ) @@ -222,8 +220,12 @@ const ( // Server Error constants. const ( - errFmtServerTLSCert = "server: tls: option 'key' must also be accompanied by option 'certificate'" - errFmtServerTLSKey = "server: tls: option 'certificate' must also be accompanied by option 'key'" + errFmtServerTLSCert = "server: tls: option 'key' must also be accompanied by option 'certificate'" + errFmtServerTLSKey = "server: tls: option 'certificate' must also be accompanied by option 'key'" + errFmtServerTLSCertFileDoesNotExist = "server: tls: file path %s provided in 'certificate' does not exist" + errFmtServerTLSKeyFileDoesNotExist = "server: tls: file path %s provided in 'key' does not exist" + errFmtServerTLSClientAuthCertFileDoesNotExist = "server: tls: client_certificates: certificates: file path %s does not exist" + errFmtServerTLSClientAuthNoAuth = "server: tls: client authentication cannot be configured if no server certificate and key are provided" errFmtServerPathNoForwardSlashes = "server: option 'path' must not contain any forward slashes" errFmtServerPathAlphaNum = "server: option 'path' must only contain alpha numeric characters" diff --git a/internal/configuration/validator/server.go b/internal/configuration/validator/server.go index a958f0e13..06f26e051 100644 --- a/internal/configuration/validator/server.go +++ b/internal/configuration/validator/server.go @@ -9,6 +9,44 @@ import ( "github.com/authelia/authelia/v4/internal/utils" ) +// validateFileExists checks whether a file exist. +func validateFileExists(path string, validator *schema.StructValidator, errTemplate string) { + exist, err := utils.FileExists(path) + if err != nil { + validator.Push(fmt.Errorf("tls: unable to check if file %s exists: %s", path, err)) + } + + if !exist { + validator.Push(fmt.Errorf(errTemplate, path)) + } +} + +// ValidateServerTLS checks a server TLS configuration is correct. +func ValidateServerTLS(config *schema.Configuration, validator *schema.StructValidator) { + if config.Server.TLS.Key != "" && config.Server.TLS.Certificate == "" { + validator.Push(fmt.Errorf(errFmtServerTLSCert)) + } else if config.Server.TLS.Key == "" && config.Server.TLS.Certificate != "" { + validator.Push(fmt.Errorf(errFmtServerTLSKey)) + } + + if config.Server.TLS.Key != "" { + validateFileExists(config.Server.TLS.Key, validator, errFmtServerTLSKeyFileDoesNotExist) + } + + if config.Server.TLS.Certificate != "" { + validateFileExists(config.Server.TLS.Certificate, validator, errFmtServerTLSCertFileDoesNotExist) + } + + if config.Server.TLS.Key == "" && config.Server.TLS.Certificate == "" && + len(config.Server.TLS.ClientCertificates) > 0 { + validator.Push(fmt.Errorf(errFmtServerTLSClientAuthNoAuth)) + } + + for _, clientCertPath := range config.Server.TLS.ClientCertificates { + validateFileExists(clientCertPath, validator, errFmtServerTLSClientAuthCertFileDoesNotExist) + } +} + // ValidateServer checks a server configuration is correct. func ValidateServer(config *schema.Configuration, validator *schema.StructValidator) { if config.Server.Host == "" { @@ -19,11 +57,7 @@ func ValidateServer(config *schema.Configuration, validator *schema.StructValida config.Server.Port = schema.DefaultServerConfiguration.Port } - if config.Server.TLS.Key != "" && config.Server.TLS.Certificate == "" { - validator.Push(fmt.Errorf(errFmtServerTLSCert)) - } else if config.Server.TLS.Key == "" && config.Server.TLS.Certificate != "" { - validator.Push(fmt.Errorf(errFmtServerTLSKey)) - } + ValidateServerTLS(config, validator) switch { case strings.Contains(config.Server.Path, "/"): diff --git a/internal/configuration/validator/server_test.go b/internal/configuration/validator/server_test.go index e11e884a2..7349465cd 100644 --- a/internal/configuration/validator/server_test.go +++ b/internal/configuration/validator/server_test.go @@ -1,6 +1,7 @@ package validator import ( + "os" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,8 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" ) +const unexistingFilePath = "/tmp/unexisting_file" + func TestShouldSetDefaultServerValues(t *testing.T) { validator := schema.NewStructValidator() config := &schema.Configuration{} @@ -119,33 +122,129 @@ func TestShouldValidateAndUpdateHost(t *testing.T) { func TestShouldRaiseErrorWhenTLSCertWithoutKeyIsProvided(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultConfig() - config.Server.TLS.Certificate = testTLSCert + + file, err := os.CreateTemp("", "cert") + require.NoError(t, err) + + defer os.Remove(file.Name()) + + config.Server.TLS.Certificate = file.Name() ValidateServer(&config, validator) require.Len(t, validator.Errors(), 1) assert.EqualError(t, validator.Errors()[0], "server: tls: option 'certificate' must also be accompanied by option 'key'") } +func TestShouldRaiseErrorWhenTLSCertDoesNotExist(t *testing.T) { + validator := schema.NewStructValidator() + config := newDefaultConfig() + + file, err := os.CreateTemp("", "key") + require.NoError(t, err) + + defer os.Remove(file.Name()) + + config.Server.TLS.Certificate = unexistingFilePath + config.Server.TLS.Key = file.Name() + + ValidateServer(&config, validator) + require.Len(t, validator.Errors(), 1) + assert.EqualError(t, validator.Errors()[0], "server: tls: file path /tmp/unexisting_file provided in 'certificate' does not exist") +} + func TestShouldRaiseErrorWhenTLSKeyWithoutCertIsProvided(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultConfig() - config.Server.TLS.Key = testTLSKey + + file, err := os.CreateTemp("", "key") + require.NoError(t, err) + + defer os.Remove(file.Name()) + + config.Server.TLS.Key = file.Name() ValidateServer(&config, validator) require.Len(t, validator.Errors(), 1) assert.EqualError(t, validator.Errors()[0], "server: tls: option 'key' must also be accompanied by option 'certificate'") } +func TestShouldRaiseErrorWhenTLSKeyDoesNotExist(t *testing.T) { + validator := schema.NewStructValidator() + config := newDefaultConfig() + + file, err := os.CreateTemp("", "key") + require.NoError(t, err) + + defer os.Remove(file.Name()) + + config.Server.TLS.Key = unexistingFilePath + config.Server.TLS.Certificate = file.Name() + + ValidateServer(&config, validator) + require.Len(t, validator.Errors(), 1) + assert.EqualError(t, validator.Errors()[0], "server: tls: file path /tmp/unexisting_file provided in 'key' does not exist") +} + func TestShouldNotRaiseErrorWhenBothTLSCertificateAndKeyAreProvided(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultConfig() - config.Server.TLS.Certificate = testTLSCert - config.Server.TLS.Key = testTLSKey + + certFile, err := os.CreateTemp("", "cert") + require.NoError(t, err) + + defer os.Remove(certFile.Name()) + + keyFile, err := os.CreateTemp("", "key") + require.NoError(t, err) + + defer os.Remove(keyFile.Name()) + + config.Server.TLS.Certificate = certFile.Name() + config.Server.TLS.Key = keyFile.Name() ValidateServer(&config, validator) require.Len(t, validator.Errors(), 0) } +func TestShouldRaiseErrorWhenTLSClientCertificateDoesNotExist(t *testing.T) { + validator := schema.NewStructValidator() + config := newDefaultConfig() + + certFile, err := os.CreateTemp("", "cert") + require.NoError(t, err) + + defer os.Remove(certFile.Name()) + + keyFile, err := os.CreateTemp("", "key") + require.NoError(t, err) + + defer os.Remove(keyFile.Name()) + + config.Server.TLS.Certificate = certFile.Name() + config.Server.TLS.Key = keyFile.Name() + config.Server.TLS.ClientCertificates = []string{"/tmp/unexisting"} + + ValidateServer(&config, validator) + require.Len(t, validator.Errors(), 1) + assert.EqualError(t, validator.Errors()[0], "server: tls: client_certificates: certificates: file path /tmp/unexisting does not exist") +} + +func TestShouldRaiseErrorWhenTLSClientAuthIsDefinedButNotServerCertificate(t *testing.T) { + validator := schema.NewStructValidator() + config := newDefaultConfig() + + certFile, err := os.CreateTemp("", "cert") + require.NoError(t, err) + + defer os.Remove(certFile.Name()) + + config.Server.TLS.ClientCertificates = []string{certFile.Name()} + + ValidateServer(&config, validator) + require.Len(t, validator.Errors(), 1) + assert.EqualError(t, validator.Errors()[0], "server: tls: client authentication cannot be configured if no server certificate and key are provided") +} + func TestShouldNotUpdateConfig(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultConfig() diff --git a/internal/server/const.go b/internal/server/const.go index 8fb605e23..a00d6b880 100644 --- a/internal/server/const.go +++ b/internal/server/const.go @@ -40,6 +40,9 @@ var ( } ) +const schemeHTTP = "http" +const schemeHTTPS = "https" + const ( dev = "dev" f = "false" diff --git a/internal/server/server.go b/internal/server/server.go index b52c87001..6bd13484c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,11 +1,15 @@ package server import ( + "crypto/tls" + "crypto/x509" "net" "os" "strconv" "time" + "github.com/authelia/authelia/v4/internal/logging" + duoapi "github.com/duosecurity/duo_api_golang" "github.com/fasthttp/router" "github.com/valyala/fasthttp" @@ -15,7 +19,6 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/duo" "github.com/authelia/authelia/v4/internal/handlers" - "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/middlewares" ) @@ -178,10 +181,8 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr return handler } -// Start Authelia's internal webserver with the given configuration and providers. -func Start(configuration schema.Configuration, providers middlewares.Providers) { - logger := logging.Logger() - +// CreateServer Create Authelia's internal webserver with the given configuration and providers. +func CreateServer(configuration schema.Configuration, providers middlewares.Providers) (*fasthttp.Server, net.Listener) { handler := registerRoutes(configuration, providers) server := &fasthttp.Server{ @@ -191,36 +192,66 @@ func Start(configuration schema.Configuration, providers middlewares.Providers) ReadBufferSize: configuration.Server.ReadBufferSize, WriteBufferSize: configuration.Server.WriteBufferSize, } + logger := logging.Logger() address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port)) - listener, err := net.Listen("tcp", address) - if err != nil { - logger.Fatalf("Error initializing listener: %s", err) - } + var ( + listener net.Listener + err error + connectionType string + connectionScheme string + ) if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" { - if err = writeHealthCheckEnv(configuration.Server.DisableHealthcheck, "https", configuration.Server.Host, configuration.Server.Path, configuration.Server.Port); err != nil { - logger.Fatalf("Could not configure healthcheck: %v", err) + connectionType, connectionScheme = "TLS", schemeHTTPS + err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key) + + if err != nil { + logger.Fatalf("unable to load certificate: %v", err) } - if configuration.Server.Path == "" { - logger.Infof("Listening for TLS connections on '%s' path '/'", address) - } else { - logger.Infof("Listening for TLS connections on '%s' paths '/' and '%s'", address, configuration.Server.Path) + if len(configuration.Server.TLS.ClientCertificates) > 0 { + caCertPool := x509.NewCertPool() + + for _, path := range configuration.Server.TLS.ClientCertificates { + cert, err := os.ReadFile(path) + if err != nil { + logger.Fatalf("Cannot read client TLS certificate %s: %s", path, err) + } + + caCertPool.AppendCertsFromPEM(cert) + } + + // ClientCAs should never be nil, otherwise the system cert pool is used for client authentication + // but we don't want everybody on the Internet to be able to authenticate. + server.TLSConfig.ClientCAs = caCertPool + server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert } - logger.Fatal(server.ServeTLS(listener, configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)) + listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()) + if err != nil { + logger.Fatalf("Error initializing listener: %s", err) + } } else { - if err = writeHealthCheckEnv(configuration.Server.DisableHealthcheck, "http", configuration.Server.Host, configuration.Server.Path, configuration.Server.Port); err != nil { - logger.Fatalf("Could not configure healthcheck: %v", err) + connectionType, connectionScheme = "non-TLS", schemeHTTP + listener, err = net.Listen("tcp", address) + if err != nil { + logger.Fatalf("Error initializing listener: %s", err) } - - if configuration.Server.Path == "" { - logger.Infof("Listening for non-TLS connections on '%s' path '/'", address) - } else { - logger.Infof("Listening for non-TLS connections on '%s' paths '/' and '%s'", address, configuration.Server.Path) - } - logger.Fatal(server.Serve(listener)) } + + if err = writeHealthCheckEnv(configuration.Server.DisableHealthcheck, connectionScheme, configuration.Server.Host, + configuration.Server.Path, configuration.Server.Port); err != nil { + logger.Fatalf("Could not configure healthcheck: %v", err) + } + + actualAddress := listener.Addr().String() + if configuration.Server.Path == "" { + logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress) + } else { + logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path) + } + + return server, listener } diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 000000000..dcf86d464 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,379 @@ +package server + +import ( + "crypto/elliptic" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/authelia/authelia/v4/internal/logging" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/utils" +) + +// TemporaryCertificate contains the FD of 2 temporary files containing the PEM format of the certificate and private key. +type TemporaryCertificate struct { + CertFile *os.File + KeyFile *os.File + + Certificate *x509.Certificate + + CertificatePEM []byte + KeyPEM []byte +} + +func (tc TemporaryCertificate) TLSCertificate() (tls.Certificate, error) { + return tls.LoadX509KeyPair(tc.CertFile.Name(), tc.KeyFile.Name()) +} + +func (tc *TemporaryCertificate) Close() { + if tc.CertFile != nil { + tc.CertFile.Close() + } + + if tc.KeyFile != nil { + tc.KeyFile.Close() + } +} + +type CertificateContext struct { + Certificates []TemporaryCertificate + privateKeyBuilder utils.PrivateKeyBuilder +} + +// NewCertificateContext instantiate a new certificate context used to easily generate certificates within tests. +func NewCertificateContext(privateKeyBuilder utils.PrivateKeyBuilder) (*CertificateContext, error) { + certificateContext := new(CertificateContext) + certificateContext.privateKeyBuilder = privateKeyBuilder + + cert, err := certificateContext.GenerateCertificate() + if err != nil { + return nil, err + } + + certificateContext.Certificates = []TemporaryCertificate{*cert} + + return certificateContext, nil +} + +// GenerateCertificate generate a new certificate in the context. +func (cc *CertificateContext) GenerateCertificate() (*TemporaryCertificate, error) { + certBytes, keyBytes, err := utils.GenerateCertificate(cc.privateKeyBuilder, + []string{"authelia.com", "example.org", "local.example.com"}, + time.Now(), 3*time.Hour, false) + if err != nil { + return nil, fmt.Errorf("unable to generate certificate: %v", err) + } + + tmpCertificate := new(TemporaryCertificate) + + certFile, err := os.CreateTemp("", "cert") + if err != nil { + return nil, fmt.Errorf("unable to create temp file for certificate: %v", err) + } + + tmpCertificate.CertFile = certFile + tmpCertificate.CertificatePEM = certBytes + + block, _ := pem.Decode(certBytes) + c, err := x509.ParseCertificate(block.Bytes) + + if err != nil { + return nil, fmt.Errorf("unable to parse certificate: %v", err) + } + + tmpCertificate.Certificate = c + + err = os.WriteFile(tmpCertificate.CertFile.Name(), certBytes, 0600) + if err != nil { + tmpCertificate.Close() + return nil, fmt.Errorf("unable to write certificates in file: %v", err) + } + + keyFile, err := os.CreateTemp("", "key") + if err != nil { + tmpCertificate.Close() + return nil, fmt.Errorf("unable to create temp file for private key: %v", err) + } + + tmpCertificate.KeyFile = keyFile + tmpCertificate.KeyPEM = keyBytes + + err = os.WriteFile(tmpCertificate.KeyFile.Name(), keyBytes, 0600) + if err != nil { + tmpCertificate.Close() + return nil, fmt.Errorf("unable to write private key in file: %v", err) + } + + cc.Certificates = append(cc.Certificates, *tmpCertificate) + + return tmpCertificate, nil +} + +func (cc *CertificateContext) Close() { + for _, tc := range cc.Certificates { + tc.Close() + } +} + +type TLSServerContext struct { + server *fasthttp.Server + port int +} + +func NewTLSServerContext(configuration schema.Configuration) (*TLSServerContext, error) { + serverContext := new(TLSServerContext) + + s, listener := CreateServer(configuration, middlewares.Providers{}) + serverContext.server = s + + go func() { + err := s.Serve(listener) + if err != nil { + logging.Logger().Fatal(err) + } + }() + + addrSplit := strings.Split(listener.Addr().String(), ":") + if len(addrSplit) > 1 { + port, err := strconv.ParseInt(addrSplit[len(addrSplit)-1], 10, 32) + if err != nil { + return nil, fmt.Errorf("unable to parse port from address: %v", err) + } + + serverContext.port = int(port) + } + + return serverContext, nil +} + +func (sc *TLSServerContext) Port() int { + return sc.port +} + +func (sc *TLSServerContext) Close() error { + return sc.server.Shutdown() +} + +func TestShouldRaiseErrorWhenClientDoesNotSkipVerify(t *testing.T) { + privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256()) + certificateContext, err := NewCertificateContext(privateKeyBuilder) + require.NoError(t, err) + + defer certificateContext.Close() + + tlsServerContext, err := NewTLSServerContext(schema.Configuration{ + Server: schema.ServerConfiguration{ + TLS: schema.ServerTLSConfiguration{ + Certificate: certificateContext.Certificates[0].CertFile.Name(), + Key: certificateContext.Certificates[0].KeyFile.Name(), + }, + }, + }) + require.NoError(t, err) + + defer tlsServerContext.Close() + + fmt.Println(tlsServerContext.Port()) + req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d", tlsServerContext.Port()), nil) + require.NoError(t, err) + + _, err = http.DefaultClient.Do(req) + require.Error(t, err) + + require.Contains(t, err.Error(), "x509: certificate signed by unknown authority") +} + +func TestShouldServeOverTLSWhenClientDoesSkipVerify(t *testing.T) { + privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256()) + certificateContext, err := NewCertificateContext(privateKeyBuilder) + require.NoError(t, err) + + defer certificateContext.Close() + + tlsServerContext, err := NewTLSServerContext(schema.Configuration{ + Server: schema.ServerConfiguration{ + TLS: schema.ServerTLSConfiguration{ + Certificate: certificateContext.Certificates[0].CertFile.Name(), + Key: certificateContext.Certificates[0].KeyFile.Name(), + }, + }, + }) + require.NoError(t, err) + + defer tlsServerContext.Close() + + req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil) + require.NoError(t, err) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in tests. Not used in production. + } + client := &http.Client{Transport: tr} + + res, err := client.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, "404 Not Found", res.Status) +} + +func TestShouldServeOverTLSWhenClientHasProperRootCA(t *testing.T) { + privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256()) + certificateContext, err := NewCertificateContext(privateKeyBuilder) + require.NoError(t, err) + + defer certificateContext.Close() + + tlsServerContext, err := NewTLSServerContext(schema.Configuration{ + Server: schema.ServerConfiguration{ + TLS: schema.ServerTLSConfiguration{ + Certificate: certificateContext.Certificates[0].CertFile.Name(), + Key: certificateContext.Certificates[0].KeyFile.Name(), + }, + }, + }) + require.NoError(t, err) + + defer tlsServerContext.Close() + + req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil) + require.NoError(t, err) + + block, _ := pem.Decode(certificateContext.Certificates[0].CertificatePEM) + c, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + // Create a root CA for the client to properly validate server cert. + rootCAs := x509.NewCertPool() + rootCAs.AddCert(c) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + MinVersion: tls.VersionTLS13, + }, + } + client := &http.Client{Transport: tr} + + res, err := client.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, "404 Not Found", res.Status) +} + +func TestShouldRaiseWhenMutualTLSIsConfiguredAndClientIsNotAuthenticated(t *testing.T) { + privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256()) + certificateContext, err := NewCertificateContext(privateKeyBuilder) + require.NoError(t, err) + + defer certificateContext.Close() + + clientCert, err := certificateContext.GenerateCertificate() + require.NoError(t, err) + + tlsServerContext, err := NewTLSServerContext(schema.Configuration{ + Server: schema.ServerConfiguration{ + TLS: schema.ServerTLSConfiguration{ + Certificate: certificateContext.Certificates[0].CertFile.Name(), + Key: certificateContext.Certificates[0].KeyFile.Name(), + ClientCertificates: []string{clientCert.CertFile.Name()}, + }, + }, + }) + require.NoError(t, err) + + defer tlsServerContext.Close() + + req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil) + require.NoError(t, err) + + // Create a root CA for the client to properly validate server cert. + rootCAs := x509.NewCertPool() + rootCAs.AddCert(certificateContext.Certificates[0].Certificate) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + MinVersion: tls.VersionTLS13, + }, + } + client := &http.Client{Transport: tr} + + _, err = client.Do(req) + require.Error(t, err) + assert.Contains(t, err.Error(), "remote error: tls: bad certificate") +} + +func TestShouldServeProperlyWhenMutualTLSIsConfiguredAndClientIsAuthenticated(t *testing.T) { + privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256()) + certificateContext, err := NewCertificateContext(privateKeyBuilder) + require.NoError(t, err) + + defer certificateContext.Close() + + clientCert, err := certificateContext.GenerateCertificate() + require.NoError(t, err) + + tlsServerContext, err := NewTLSServerContext(schema.Configuration{ + Server: schema.ServerConfiguration{ + TLS: schema.ServerTLSConfiguration{ + Certificate: certificateContext.Certificates[0].CertFile.Name(), + Key: certificateContext.Certificates[0].KeyFile.Name(), + ClientCertificates: []string{clientCert.CertFile.Name()}, + }, + }, + }) + require.NoError(t, err) + + defer tlsServerContext.Close() + + req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil) + require.NoError(t, err) + + // Create a root CA for the client to properly validate server cert. + rootCAs := x509.NewCertPool() + rootCAs.AddCert(certificateContext.Certificates[0].Certificate) + + cCert, err := certificateContext.Certificates[1].TLSCertificate() + require.NoError(t, err) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + Certificates: []tls.Certificate{cCert}, + MinVersion: tls.VersionTLS13, + }, + } + client := &http.Client{Transport: tr} + + res, err := client.Do(req) + require.NoError(t, err) + + defer res.Body.Close() + + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, "404 Not Found", res.Status) +} diff --git a/internal/suites/environment.go b/internal/suites/environment.go index da70eed8f..bbdbd5909 100644 --- a/internal/suites/environment.go +++ b/internal/suites/environment.go @@ -46,7 +46,7 @@ func waitUntilAutheliaBackendIsReady(dockerEnvironment *DockerEnvironment) error 90*time.Second, dockerEnvironment, "authelia-backend", - []string{"Listening for"}) + []string{"Initializing server for"}) } func waitUntilAutheliaFrontendIsReady(dockerEnvironment *DockerEnvironment) error { diff --git a/internal/suites/suite_cli_test.go b/internal/suites/suite_cli_test.go index d6e322add..846ce23fe 100644 --- a/internal/suites/suite_cli_test.go +++ b/internal/suites/suite_cli_test.go @@ -45,6 +45,10 @@ func (s *CLISuite) SetupTest() { } func (s *CLISuite) TestShouldPrintBuildInformation() { + if os.Getenv("CI") == "false" { + s.T().Skip("Skipping testing in dev environment") + } + output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "build-info"}) s.Assert().NoError(err) s.Assert().Contains(output, "Last Tag: ") @@ -92,22 +96,22 @@ func (s *CLISuite) TestShouldHashPasswordSHA512() { func (s *CLISuite) TestShouldGenerateCertificateRSA() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldGenerateCertificateRSAWithIPAddress() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=127.0.0.1", "--dir=/tmp/"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldGenerateCertificateRSAWithStartDate() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--start-date='Jan 1 15:04:05 2011'"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldFailGenerateCertificateRSAWithStartDate() { @@ -119,15 +123,15 @@ func (s *CLISuite) TestShouldFailGenerateCertificateRSAWithStartDate() { func (s *CLISuite) TestShouldGenerateCertificateCA() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ca"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldGenerateCertificateEd25519() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ed25519"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldFailGenerateCertificateECDSA() { @@ -139,29 +143,29 @@ func (s *CLISuite) TestShouldFailGenerateCertificateECDSA() { func (s *CLISuite) TestShouldGenerateCertificateECDSAP224() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P224"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldGenerateCertificateECDSAP256() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P256"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldGenerateCertificateECDSAP384() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P384"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestShouldGenerateCertificateECDSAP521() { output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "certificates", "generate", "--host=*.example.com", "--dir=/tmp/", "--ecdsa-curve=P521"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Certificate Public Key written to /tmp/cert.pem") - s.Assert().Contains(output, "Certificate Private Key written to /tmp/key.pem") + s.Assert().Contains(output, "Certificate written to /tmp/cert.pem") + s.Assert().Contains(output, "Private Key written to /tmp/key.pem") } func (s *CLISuite) TestStorageShouldShowErrWithoutConfig() { diff --git a/internal/utils/certificates.go b/internal/utils/certificates.go index 523dd28cd..bc7d2952a 100644 --- a/internal/utils/certificates.go +++ b/internal/utils/certificates.go @@ -1,17 +1,38 @@ package utils import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" + "math/big" + "net" "os" "path/filepath" "strings" + "time" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" ) +// PEMBlockType represent an enum of the existing PEM block types. +type PEMBlockType int + +const ( + // Certificate block type. + Certificate PEMBlockType = iota + // PrivateKey block type. + PrivateKey +) + // NewTLSConfig generates a tls.Config from a schema.TLSConfig and a x509.CertPool. func NewTLSConfig(config *schema.TLSConfig, defaultMinVersion uint16, certPool *x509.CertPool) (tlsConfig *tls.Config) { minVersion, err := TLSStringToTLSConfigVersion(config.MinimumVersion) @@ -83,3 +104,150 @@ func TLSStringToTLSConfigVersion(input string) (version uint16, err error) { return 0, ErrTLSVersionNotSupported } + +// GenerateCertificate generate a certificate given a private key. RSA, Ed25519 and ECDSA are officially supported. +func GenerateCertificate(privateKeyBuilder PrivateKeyBuilder, hosts []string, validFrom time.Time, validFor time.Duration, isCA bool) ([]byte, []byte, error) { + privateKey, err := privateKeyBuilder.Build() + if err != nil { + return nil, nil, fmt.Errorf("unable to build private key: %w", err) + } + + notBefore := validFrom + notAfter := validFrom.Add(validFor) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + if isCA { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + } + + certDERBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %v", err) + } + + certPEMBytes, err := ConvertDERToPEM(certDERBytes, Certificate) + if err != nil { + return nil, nil, fmt.Errorf("faile to convert certificate in DER format into PEM: %v", err) + } + + keyDERBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal private key: %v", err) + } + + keyPEMBytes, err := ConvertDERToPEM(keyDERBytes, PrivateKey) + if err != nil { + return nil, nil, fmt.Errorf("faile to convert certificate in DER format into PEM: %v", err) + } + + return certPEMBytes, keyPEMBytes, nil +} + +// ConvertDERToPEM convert certificate in DER format into PEM format. +func ConvertDERToPEM(der []byte, blockType PEMBlockType) ([]byte, error) { + var buf bytes.Buffer + + var blockTypeStr string + + switch blockType { + case Certificate: + blockTypeStr = "CERTIFICATE" + case PrivateKey: + blockTypeStr = "PRIVATE KEY" + default: + return nil, fmt.Errorf("unknown PEM block type %d", blockType) + } + + if err := pem.Encode(&buf, &pem.Block{Type: blockTypeStr, Bytes: der}); err != nil { + return nil, fmt.Errorf("failed to encode DER data into PEM: %v", err) + } + + return buf.Bytes(), nil +} + +func publicKey(privateKey interface{}) interface{} { + switch k := privateKey.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + case ed25519.PrivateKey: + return k.Public().(ed25519.PublicKey) + default: + return nil + } +} + +// PrivateKeyBuilder interface for a private key builder. +type PrivateKeyBuilder interface { + Build() (interface{}, error) +} + +// RSAKeyBuilder builder of RSA private key. +type RSAKeyBuilder struct { + keySizeInBits int +} + +// WithKeySize configure the key size to use with RSA. +func (rkb RSAKeyBuilder) WithKeySize(bits int) RSAKeyBuilder { + rkb.keySizeInBits = bits + return rkb +} + +// Build a RSA private key. +func (rkb RSAKeyBuilder) Build() (interface{}, error) { + return rsa.GenerateKey(rand.Reader, rkb.keySizeInBits) +} + +// Ed25519KeyBuilder builder of Ed25519 private key. +type Ed25519KeyBuilder struct{} + +// Build an Ed25519 private key. +func (ekb Ed25519KeyBuilder) Build() (interface{}, error) { + _, priv, err := ed25519.GenerateKey(rand.Reader) + return priv, err +} + +// ECDSAKeyBuilder builder of ECDSA private key. +type ECDSAKeyBuilder struct { + curve elliptic.Curve +} + +// WithCurve configure the curve to use for the ECDSA private key. +func (ekb ECDSAKeyBuilder) WithCurve(curve elliptic.Curve) ECDSAKeyBuilder { + ekb.curve = curve + return ekb +} + +// Build an ECDSA private key. +func (ekb ECDSAKeyBuilder) Build() (interface{}, error) { + return ecdsa.GenerateKey(ekb.curve, rand.Reader) +} diff --git a/internal/utils/certificates_test.go b/internal/utils/certificates_test.go index 2cf103076..918678fe8 100644 --- a/internal/utils/certificates_test.go +++ b/internal/utils/certificates_test.go @@ -1,9 +1,11 @@ package utils import ( + "crypto/elliptic" "crypto/tls" "runtime" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -124,3 +126,44 @@ func TestShouldReadCertsFromDirectoryButNotKeys(t *testing.T) { assert.EqualError(t, errors[0], "could not import certificate key.pem") } + +func TestShouldGenerateCertificateAndPersistIt(t *testing.T) { + testCases := []struct { + Name string + PrivateKeyBuilder PrivateKeyBuilder + }{ + { + Name: "P224", + PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P224()), + }, + { + Name: "P256", + PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P256()), + }, + { + Name: "P384", + PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P384()), + }, + { + Name: "P521", + PrivateKeyBuilder: ECDSAKeyBuilder{}.WithCurve(elliptic.P521()), + }, + { + Name: "Ed25519", + PrivateKeyBuilder: Ed25519KeyBuilder{}, + }, + { + Name: "RSA", + PrivateKeyBuilder: RSAKeyBuilder{keySizeInBits: 2048}, + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + certBytes, keyBytes, err := GenerateCertificate(tc.PrivateKeyBuilder, []string{"authelia.com", "example.org"}, time.Now(), 3*time.Hour, false) + require.NoError(t, err) + assert.True(t, len(certBytes) > 0) + assert.True(t, len(keyBytes) > 0) + }) + } +}