// Source: authelia/internal/server/server_test.go (384 lines, 11 KiB, Go).

package server
import (
"crypto/elliptic"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/utils"
)
// TemporaryCertificate bundles a generated certificate with the two temporary
// files that hold its PEM encoded certificate and private key.
type TemporaryCertificate struct {
	// CertFile and KeyFile are the handles of the temporary files the
	// certificate and private key PEM data are written to.
	CertFile, KeyFile *os.File

	// Certificate is the parsed form of the generated certificate.
	Certificate *x509.Certificate

	// CertificatePEM and KeyPEM are the raw PEM encoded certificate and private key.
	CertificatePEM, KeyPEM []byte
}
// TLSCertificate loads the key pair stored in the certificate and key temporary
// files and returns it as a tls.Certificate. Both file handles must be set.
func (tc TemporaryCertificate) TLSCertificate() (tls.Certificate, error) {
	certPath, keyPath := tc.CertFile.Name(), tc.KeyFile.Name()

	return tls.LoadX509KeyPair(certPath, keyPath)
}
// Close closes both temporary files and removes them from disk, so that test
// runs do not leave certificates and private keys behind in the temporary
// directory. Handles that were never set (nil) are skipped.
//
// Cleanup is best-effort: errors from closing or removing are deliberately
// ignored, which also makes calling Close more than once harmless.
func (tc *TemporaryCertificate) Close() {
	for _, f := range []*os.File{tc.CertFile, tc.KeyFile} {
		if f == nil {
			continue
		}

		_ = f.Close()
		_ = os.Remove(f.Name())
	}
}
// CertificateContext keeps track of the certificates generated for a test
// together with the builder used when generating them.
type CertificateContext struct {
	// Certificates lists the certificates recorded in this context.
	Certificates []TemporaryCertificate

	// privateKeyBuilder is handed to utils.GenerateCertificate for every
	// certificate generated through this context.
	privateKeyBuilder utils.PrivateKeyBuilder
}
// NewCertificateContext creates a certificate context used to easily generate
// certificates within tests. It generates one initial certificate with the
// given private key builder; on success that certificate is the only entry of
// Certificates. Any generation error is returned as is.
func NewCertificateContext(privateKeyBuilder utils.PrivateKeyBuilder) (*CertificateContext, error) {
	cc := &CertificateContext{privateKeyBuilder: privateKeyBuilder}

	initial, err := cc.GenerateCertificate()
	if err != nil {
		return nil, err
	}

	// Set the slice explicitly so it holds exactly the initial certificate.
	cc.Certificates = []TemporaryCertificate{*initial}

	return cc, nil
}
// GenerateCertificate generates a new certificate in the context.
//
// It asks utils.GenerateCertificate for a certificate covering authelia.com,
// example.org and local.example.com, passing time.Now() and a 3 hour duration
// (presumably the validity window; confirm against utils.GenerateCertificate).
// The PEM encoded certificate and private key are written to two temporary
// files and the result is appended to cc.Certificates.
//
// An error is returned when the certificate cannot be generated, decoded or
// parsed, or when a temporary file cannot be created or written. On failure
// every temporary file opened so far is closed and nothing is appended to the
// context.
func (cc *CertificateContext) GenerateCertificate() (*TemporaryCertificate, error) {
	certBytes, keyBytes, err := utils.GenerateCertificate(cc.privateKeyBuilder,
		[]string{"authelia.com", "example.org", "local.example.com"},
		time.Now(), 3*time.Hour, false)
	if err != nil {
		return nil, fmt.Errorf("unable to generate certificate: %w", err)
	}

	// Decode and parse before touching the filesystem so that invalid data
	// cannot leave an open temporary file behind.
	block, _ := pem.Decode(certBytes)
	if block == nil {
		return nil, fmt.Errorf("unable to parse certificate: no PEM data found")
	}

	c, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("unable to parse certificate: %w", err)
	}

	tmpCertificate := &TemporaryCertificate{
		Certificate:    c,
		CertificatePEM: certBytes,
		KeyPEM:         keyBytes,
	}

	if tmpCertificate.CertFile, err = os.CreateTemp("", "cert"); err != nil {
		return nil, fmt.Errorf("unable to create temp file for certificate: %w", err)
	}

	if err = os.WriteFile(tmpCertificate.CertFile.Name(), certBytes, 0600); err != nil {
		tmpCertificate.Close()

		return nil, fmt.Errorf("unable to write certificates in file: %w", err)
	}

	if tmpCertificate.KeyFile, err = os.CreateTemp("", "key"); err != nil {
		tmpCertificate.Close()

		return nil, fmt.Errorf("unable to create temp file for private key: %w", err)
	}

	if err = os.WriteFile(tmpCertificate.KeyFile.Name(), keyBytes, 0600); err != nil {
		tmpCertificate.Close()

		return nil, fmt.Errorf("unable to write private key in file: %w", err)
	}

	cc.Certificates = append(cc.Certificates, *tmpCertificate)

	return tmpCertificate, nil
}
// Close calls Close on every certificate recorded in the context.
func (cc *CertificateContext) Close() {
	for i := range cc.Certificates {
		cc.Certificates[i].Close()
	}
}
// TLSServerContext holds a server started for a test and the port it listens on.
type TLSServerContext struct {
	// server is the fasthttp server serving requests in the background.
	server *fasthttp.Server

	// port is the port parsed from the listener address.
	port int
}
// NewTLSServerContext creates the default server for the given configuration
// (with empty providers) and serves its listener in a background goroutine.
//
// The port is taken from the text after the last ':' of the listener address;
// when the address contains no ':' the port is left at 0. If the port cannot
// be parsed, the listener is closed and an error is returned without starting
// the server. A serving error in the background goroutine is reported through
// the logger's Fatal.
func NewTLSServerContext(configuration schema.Configuration) (*TLSServerContext, error) {
	s, listener, err := CreateDefaultServer(configuration, middlewares.Providers{})
	if err != nil {
		return nil, err
	}

	serverContext := &TLSServerContext{server: s}

	// Resolve the port before serving: failing afterwards would leave a running
	// server that the caller has no handle to shut down.
	addr := listener.Addr().String()
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		port, err := strconv.ParseInt(addr[i+1:], 10, 32)
		if err != nil {
			_ = listener.Close()

			return nil, fmt.Errorf("unable to parse port from address: %w", err)
		}

		serverContext.port = int(port)
	}

	go func() {
		if err := s.Serve(listener); err != nil {
			logging.Logger().Fatal(err)
		}
	}()

	return serverContext, nil
}
// Port returns the port recorded when the server context was created.
func (sc *TLSServerContext) Port() int {
	return sc.port
}
// Close shuts the server down and returns the error reported by the shutdown, if any.
func (sc *TLSServerContext) Close() error {
	err := sc.server.Shutdown()

	return err
}
// TestShouldRaiseErrorWhenClientDoesNotSkipVerify checks that the default HTTP
// client, which has not been told to trust the generated server certificate,
// fails the request with an "unknown authority" verification error.
func TestShouldRaiseErrorWhenClientDoesNotSkipVerify(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())

	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate: certificateContext.Certificates[0].CertFile.Name(),
				Key:         certificateContext.Certificates[0].KeyFile.Name(),
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://local.example.com:%d", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	res, err := http.DefaultClient.Do(req)
	if res != nil {
		// Only reachable when the request unexpectedly succeeds; do not leak the body.
		defer res.Body.Close()
	}

	require.Error(t, err)
	require.Contains(t, err.Error(), "x509: certificate signed by unknown authority")
}
// TestShouldServeOverTLSWhenClientDoesSkipVerify checks that a client which
// skips certificate verification reaches the server over TLS and gets a
// 404 response for an unknown API path.
func TestShouldServeOverTLSWhenClientDoesSkipVerify(t *testing.T) {
	keyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())

	certs, err := NewCertificateContext(keyBuilder)
	require.NoError(t, err)

	defer certs.Close()

	serverCert := certs.Certificates[0]

	srv, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate: serverCert.CertFile.Name(),
				Key:         serverCert.KeyFile.Name(),
			},
		},
	})
	require.NoError(t, err)

	defer srv.Close()

	target := fmt.Sprintf("https://local.example.com:%d/api/notfound", srv.Port())

	req, err := http.NewRequest("GET", target, nil)
	require.NoError(t, err)

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in tests. Not used in production.
		},
	}

	res, err := client.Do(req)
	require.NoError(t, err)

	defer res.Body.Close()

	// Drain the body; only the status matters for this test.
	_, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	assert.Equal(t, "404 Not Found", res.Status)
}
// TestShouldServeOverTLSWhenClientHasProperRootCA checks that a client whose
// root CA pool contains the server certificate (re-parsed from its PEM data)
// completes the TLS handshake and gets a 404 response for an unknown API path.
func TestShouldServeOverTLSWhenClientHasProperRootCA(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())

	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate: certificateContext.Certificates[0].CertFile.Name(),
				Key:         certificateContext.Certificates[0].KeyFile.Name(),
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	// pem.Decode returns a nil block when no PEM data is found; fail the test
	// cleanly instead of panicking on the dereference below.
	block, _ := pem.Decode(certificateContext.Certificates[0].CertificatePEM)
	require.NotNil(t, block)

	c, err := x509.ParseCertificate(block.Bytes)
	require.NoError(t, err)

	// Create a root CA for the client to properly validate server cert.
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(c)

	tr := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs:    rootCAs,
			MinVersion: tls.VersionTLS13,
		},
	}

	client := &http.Client{Transport: tr}

	res, err := client.Do(req)
	require.NoError(t, err)

	defer res.Body.Close()

	_, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	assert.Equal(t, "404 Not Found", res.Status)
}
// TestShouldRaiseWhenMutualTLSIsConfiguredAndClientIsNotAuthenticated checks
// that, with client certificates configured on the server, a client presenting
// no certificate has its request fail with a "bad certificate" TLS error.
func TestShouldRaiseWhenMutualTLSIsConfiguredAndClientIsNotAuthenticated(t *testing.T) {
	keyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())

	certs, err := NewCertificateContext(keyBuilder)
	require.NoError(t, err)

	defer certs.Close()

	clientCert, err := certs.GenerateCertificate()
	require.NoError(t, err)

	srv, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate:        certs.Certificates[0].CertFile.Name(),
				Key:                certs.Certificates[0].KeyFile.Name(),
				ClientCertificates: []string{clientCert.CertFile.Name()},
			},
		},
	})
	require.NoError(t, err)

	defer srv.Close()

	target := fmt.Sprintf("https://local.example.com:%d/api/notfound", srv.Port())

	req, err := http.NewRequest("GET", target, nil)
	require.NoError(t, err)

	// Trust the server certificate so that only the missing client certificate
	// can make the request fail.
	pool := x509.NewCertPool()
	pool.AddCert(certs.Certificates[0].Certificate)

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				RootCAs:    pool,
				MinVersion: tls.VersionTLS13,
			},
		},
	}

	_, err = client.Do(req)
	require.Error(t, err)

	assert.Contains(t, err.Error(), "remote error: tls: bad certificate")
}
// TestShouldServeProperlyWhenMutualTLSIsConfiguredAndClientIsAuthenticated
// checks that, with client certificates configured on the server, a client
// presenting the configured certificate is served and gets a 404 response for
// an unknown API path.
func TestShouldServeProperlyWhenMutualTLSIsConfiguredAndClientIsAuthenticated(t *testing.T) {
	keyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())

	certs, err := NewCertificateContext(keyBuilder)
	require.NoError(t, err)

	defer certs.Close()

	clientCert, err := certs.GenerateCertificate()
	require.NoError(t, err)

	srv, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate:        certs.Certificates[0].CertFile.Name(),
				Key:                certs.Certificates[0].KeyFile.Name(),
				ClientCertificates: []string{clientCert.CertFile.Name()},
			},
		},
	})
	require.NoError(t, err)

	defer srv.Close()

	target := fmt.Sprintf("https://local.example.com:%d/api/notfound", srv.Port())

	req, err := http.NewRequest("GET", target, nil)
	require.NoError(t, err)

	// Trust the server certificate on the client side.
	pool := x509.NewCertPool()
	pool.AddCert(certs.Certificates[0].Certificate)

	// The second certificate of the context is the one generated for the client.
	keyPair, err := certs.Certificates[1].TLSCertificate()
	require.NoError(t, err)

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				RootCAs:      pool,
				Certificates: []tls.Certificate{keyPair},
				MinVersion:   tls.VersionTLS13,
			},
		},
	}

	res, err := client.Do(req)
	require.NoError(t, err)

	defer res.Body.Close()

	// Drain the body; only the status matters for this test.
	_, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	assert.Equal(t, "404 Not Found", res.Status)
}