refactor(server): simplify templating and url derivation (#4547)

This refactors a few areas of the server templating and related functions.
pull/4585/head
James Elliott 2022-12-17 11:49:05 +11:00 committed by GitHub
parent 3de693623e
commit d13247ce43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 180 additions and 188 deletions

View File

@ -1,6 +1,13 @@
--- ---
extends: default extends: default
locale: en_US.UTF-8
yaml-files:
- '*.yaml'
- '*.yml'
- '.yamllint'
ignore: | ignore: |
docs/pnpm-lock.yaml docs/pnpm-lock.yaml
internal/configuration/test_resources/config_bad_quoting.yml internal/configuration/test_resources/config_bad_quoting.yml

View File

@ -52,13 +52,7 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr
return return
} }
if issuer, err = ctx.IssuerURL(); err != nil { issuer = ctx.RootURL()
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred determining issuer: %+v", requester.GetID(), clientID, err)
ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, oidc.ErrIssuerCouldNotDerive)
return
}
userSession := ctx.GetSession() userSession := ctx.GetSession()

View File

@ -130,12 +130,7 @@ func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
query url.Values query url.Values
) )
if redirectURI, err = ctx.IssuerURL(); err != nil { redirectURI = ctx.RootURL()
ctx.Logger.Errorf("Failed to parse the consent redirect URL: %+v", err)
ctx.SetJSONError(messageOperationFailed)
return
}
if query, err = url.ParseQuery(consent.Form); err != nil { if query, err = url.ParseQuery(consent.Form); err != nil {
ctx.Logger.Errorf("Failed to parse the consent form values: %+v", err) ctx.Logger.Errorf("Failed to parse the consent form values: %+v", err)

View File

@ -20,13 +20,7 @@ func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
err error err error
) )
if issuer, err = ctx.IssuerURL(); err != nil { issuer = ctx.RootURL()
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
return
}
wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer.String()) wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer.String())
@ -52,13 +46,7 @@ func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
err error err error
) )
if issuer, err = ctx.IssuerURL(); err != nil { issuer = ctx.RootURL()
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
return
}
wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer.String()) wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer.String())

View File

@ -144,11 +144,7 @@ func handleOIDCWorkflowResponseWithTargetURL(ctx *middlewares.AutheliaCtx, targe
return return
} }
if issuerURL, err = ctx.IssuerURL(); err != nil { issuerURL = ctx.RootURL()
ctx.Error(fmt.Errorf("unable to get issuer for redirection: %w", err), messageAuthenticationFailed)
return
}
if targetURL.Host != issuerURL.Host { if targetURL.Host != issuerURL.Host {
ctx.Error(fmt.Errorf("unable to redirect to '%s': target host '%s' does not match expected issuer host '%s'", targetURL, targetURL.Host, issuerURL.Host), messageAuthenticationFailed) ctx.Error(fmt.Errorf("unable to redirect to '%s': target host '%s' does not match expected issuer host '%s'", targetURL, targetURL.Host, issuerURL.Host), messageAuthenticationFailed)
@ -221,11 +217,7 @@ func handleOIDCWorkflowResponseWithID(ctx *middlewares.AutheliaCtx, id string) {
form url.Values form url.Values
) )
if targetURL, err = ctx.IssuerURL(); err != nil { targetURL = ctx.RootURL()
ctx.Error(fmt.Errorf("unable to get issuer for redirection: %w", err), messageAuthenticationFailed)
return
}
if form, err = consent.GetForm(); err != nil { if form, err = consent.GetForm(); err != nil {
ctx.Error(fmt.Errorf("unable to get authorization form values from consent session with challenge id '%s': %w", consent.ChallengeID, err), messageAuthenticationFailed) ctx.Error(fmt.Errorf("unable to get authorization form values from consent session with challenge id '%s': %w", consent.ChallengeID, err), messageAuthenticationFailed)

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"path"
"strings" "strings"
"github.com/asaskevich/govalidator" "github.com/asaskevich/govalidator"
@ -81,7 +80,7 @@ func (ctx *AutheliaCtx) ReplyError(err error, message string) {
ctx.Logger.Error(marshalErr) ctx.Logger.Error(marshalErr)
} }
ctx.SetContentTypeBytes(contentTypeApplicationJSON) ctx.SetContentTypeApplicationJSON()
ctx.SetBody(b) ctx.SetBody(b)
ctx.Logger.Debug(err) ctx.Logger.Debug(err)
} }
@ -90,7 +89,7 @@ func (ctx *AutheliaCtx) ReplyError(err error, message string) {
func (ctx *AutheliaCtx) ReplyStatusCode(statusCode int) { func (ctx *AutheliaCtx) ReplyStatusCode(statusCode int) {
ctx.Response.Reset() ctx.Response.Reset()
ctx.SetStatusCode(statusCode) ctx.SetStatusCode(statusCode)
ctx.SetContentTypeBytes(contentTypeTextPlain) ctx.SetContentTypeTextPlain()
ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode))) ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode)))
} }
@ -108,7 +107,7 @@ func (ctx *AutheliaCtx) ReplyJSON(data any, statusCode int) (err error) {
ctx.SetStatusCode(statusCode) ctx.SetStatusCode(statusCode)
} }
ctx.SetContentTypeBytes(contentTypeApplicationJSON) ctx.SetContentTypeApplicationJSON()
ctx.SetBody(body) ctx.SetBody(body)
return nil return nil
@ -145,7 +144,7 @@ func (ctx *AutheliaCtx) XForwardedProto() (proto []byte) {
} }
// XForwardedMethod return the content of the X-Forwarded-Method header. // XForwardedMethod return the content of the X-Forwarded-Method header.
func (ctx *AutheliaCtx) XForwardedMethod() (method []byte) { func (ctx *AutheliaCtx) XForwardedMethod() []byte {
return ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod) return ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
} }
@ -171,79 +170,61 @@ func (ctx *AutheliaCtx) XForwardedURI() (uri []byte) {
return uri return uri
} }
// XAutheliaURL return the content of the X-Authelia-URL header. // XOriginalURL returns the content of the X-Original-URL header.
func (ctx *AutheliaCtx) XAutheliaURL() (autheliaURL []byte) { func (ctx *AutheliaCtx) XOriginalURL() []byte {
return ctx.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
}
// XOriginalMethod return the content of the X-Original-Method header.
func (ctx *AutheliaCtx) XOriginalMethod() []byte {
return ctx.RequestCtx.Request.Header.PeekBytes(headerXOriginalMethod)
}
// XAutheliaURL return the content of the X-Authelia-URL header which is used to communicate the location of the
// portal when using proxies like Envoy.
func (ctx *AutheliaCtx) XAutheliaURL() []byte {
return ctx.RequestCtx.Request.Header.PeekBytes(headerXAutheliaURL) return ctx.RequestCtx.Request.Header.PeekBytes(headerXAutheliaURL)
} }
// QueryArgRedirect return the content of the rd query argument. // QueryArgRedirect return the content of the rd query argument.
func (ctx *AutheliaCtx) QueryArgRedirect() (val []byte) { func (ctx *AutheliaCtx) QueryArgRedirect() []byte {
return ctx.RequestCtx.QueryArgs().PeekBytes(queryArgRedirect) return ctx.RequestCtx.QueryArgs().PeekBytes(qryArgRedirect)
} }
// BasePath returns the base_url as per the path visited by the client. // BasePath returns the base_url as per the path visited by the client.
func (ctx *AutheliaCtx) BasePath() (base string) { func (ctx *AutheliaCtx) BasePath() string {
if baseURL := ctx.UserValueBytes(UserValueKeyBaseURL); baseURL != nil { if baseURL := ctx.UserValueBytes(UserValueKeyBaseURL); baseURL != nil {
return baseURL.(string) return baseURL.(string)
} }
return base return ""
} }
// ExternalRootURL gets the X-Forwarded-Proto, X-Forwarded-Host headers and the BasePath and forms them into a URL. // BasePathSlash is the same as BasePath but returns a final slash as well.
func (ctx *AutheliaCtx) ExternalRootURL() (string, error) { func (ctx *AutheliaCtx) BasePathSlash() string {
protocol := ctx.XForwardedProto() if baseURL := ctx.UserValueBytes(UserValueKeyBaseURL); baseURL != nil {
if protocol == nil { return baseURL.(string) + strSlash
return "", errMissingXForwardedProto
} }
host := ctx.XForwardedHost() return strSlash
if host == nil {
return "", errMissingXForwardedHost
}
externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
if base := ctx.BasePath(); base != "" {
externalBaseURL, err := url.ParseRequestURI(externalRootURL)
if err != nil {
return "", err
}
externalBaseURL.Path = path.Join(externalBaseURL.Path, base)
return externalBaseURL.String(), nil
}
return externalRootURL, nil
} }
// IssuerURL returns the expected Issuer. // RootURL returns the Root URL.
func (ctx *AutheliaCtx) IssuerURL() (issuerURL *url.URL, err error) { func (ctx *AutheliaCtx) RootURL() (issuerURL *url.URL) {
issuerURL = &url.URL{ return &url.URL{
Scheme: "https", Scheme: string(ctx.XForwardedProto()),
Host: string(ctx.XForwardedHost()),
Path: ctx.BasePath(),
} }
if scheme := ctx.XForwardedProto(); scheme != nil {
issuerURL.Scheme = string(scheme)
}
if host := ctx.XForwardedHost(); len(host) != 0 {
issuerURL.Host = string(host)
} else {
return nil, errMissingXForwardedHost
}
if base := ctx.BasePath(); base != "" {
issuerURL.Path = path.Join(issuerURL.Path, base)
}
return issuerURL, nil
} }
// XOriginalURL return the content of the X-Original-URL header. // RootURLSlash is the same as RootURL but includes a final slash as well.
func (ctx *AutheliaCtx) XOriginalURL() []byte { func (ctx *AutheliaCtx) RootURLSlash() (issuerURL *url.URL) {
return ctx.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL) return &url.URL{
Scheme: string(ctx.XForwardedProto()),
Host: string(ctx.XForwardedHost()),
Path: ctx.BasePathSlash(),
}
} }
// GetSession return the user session. Any update will be saved in cache. // GetSession return the user session. Any update will be saved in cache.
@ -264,7 +245,7 @@ func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error {
// ReplyOK is a helper method to reply ok. // ReplyOK is a helper method to reply ok.
func (ctx *AutheliaCtx) ReplyOK() { func (ctx *AutheliaCtx) ReplyOK() {
ctx.SetContentTypeBytes(contentTypeApplicationJSON) ctx.SetContentTypeApplicationJSON()
ctx.SetBody(okMessageBytes) ctx.SetBody(okMessageBytes)
} }
@ -377,7 +358,7 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
statusCode = fasthttp.StatusFound statusCode = fasthttp.StatusFound
} }
ctx.SetContentTypeBytes(contentTypeTextHTML) ctx.SetContentTypeTextHTML()
ctx.SetStatusCode(statusCode) ctx.SetStatusCode(statusCode)
u := fasthttp.AcquireURI() u := fasthttp.AcquireURI()
@ -400,3 +381,18 @@ func (ctx *AutheliaCtx) RecordAuthentication(success, regulated bool, method str
ctx.Providers.Metrics.RecordAuthentication(success, regulated, method) ctx.Providers.Metrics.RecordAuthentication(success, regulated, method)
} }
// SetContentTypeTextPlain efficiently sets the Content-Type header to 'text/plain; charset=utf-8'.
func (ctx *AutheliaCtx) SetContentTypeTextPlain() {
ctx.SetContentTypeBytes(contentTypeTextPlain)
}
// SetContentTypeTextHTML efficiently sets the Content-Type header to 'text/html; charset=utf-8'.
func (ctx *AutheliaCtx) SetContentTypeTextHTML() {
ctx.SetContentTypeBytes(contentTypeTextHTML)
}
// SetContentTypeApplicationJSON efficiently sets the Content-Type header to 'application/json; charset=utf-8'.
func (ctx *AutheliaCtx) SetContentTypeApplicationJSON() {
ctx.SetContentTypeBytes(contentTypeApplicationJSON)
}

View File

@ -21,7 +21,6 @@ func TestIssuerURL(t *testing.T) {
name string name string
proto, host, base string proto, host, base string
expected string expected string
err string
}{ }{
{ {
name: "Standard", name: "Standard",
@ -36,7 +35,7 @@ func TestIssuerURL(t *testing.T) {
{ {
name: "NoHost", name: "NoHost",
proto: "https", host: "", base: "", proto: "https", host: "", base: "",
err: "Missing header X-Forwarded-Host", expected: "https:",
}, },
} }
@ -52,21 +51,14 @@ func TestIssuerURL(t *testing.T) {
mock.Ctx.SetUserValue("base_url", tc.base) mock.Ctx.SetUserValue("base_url", tc.base)
} }
actual, err := mock.Ctx.IssuerURL() actual := mock.Ctx.RootURL()
switch tc.err {
case "":
assert.NoError(t, err)
require.NotNil(t, actual) require.NotNil(t, actual)
assert.Equal(t, tc.expected, actual.String()) assert.Equal(t, tc.expected, actual.String())
assert.Equal(t, tc.proto, actual.Scheme) assert.Equal(t, tc.proto, actual.Scheme)
assert.Equal(t, tc.host, actual.Host) assert.Equal(t, tc.host, actual.Host)
assert.Equal(t, tc.base, actual.Path) assert.Equal(t, tc.base, actual.Path)
default:
assert.EqualError(t, err, tc.err)
assert.Nil(t, actual)
}
}) })
} }
} }

View File

@ -20,6 +20,7 @@ var (
headerXForwardedURI = []byte("X-Forwarded-URI") headerXForwardedURI = []byte("X-Forwarded-URI")
headerXOriginalURL = []byte("X-Original-URL") headerXOriginalURL = []byte("X-Original-URL")
headerXOriginalMethod = []byte("X-Original-Method")
headerXForwardedMethod = []byte("X-Forwarded-Method") headerXForwardedMethod = []byte("X-Forwarded-Method")
headerVary = []byte(fasthttp.HeaderVary) headerVary = []byte(fasthttp.HeaderVary)
@ -67,13 +68,17 @@ var (
const ( const (
strProtoHTTPS = "https" strProtoHTTPS = "https"
strProtoHTTP = "http" strProtoHTTP = "http"
strSlash = "/"
queryArgRedirect = "rd"
queryArgToken = "token"
) )
var ( var (
protoHTTPS = []byte(strProtoHTTPS) protoHTTPS = []byte(strProtoHTTPS)
protoHTTP = []byte(strProtoHTTP) protoHTTP = []byte(strProtoHTTP)
queryArgRedirect = []byte("rd") qryArgRedirect = []byte(queryArgRedirect)
// UserValueKeyBaseURL is the User Value key where we store the Base URL. // UserValueKeyBaseURL is the User Value key where we store the Base URL.
UserValueKeyBaseURL = []byte("base_url") UserValueKeyBaseURL = []byte("base_url")

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/mail" "net/mail"
"path"
"time" "time"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
@ -51,7 +52,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs, delayFunc Tim
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, err := token.SignedString([]byte(ctx.Configuration.JWTSecret)) signedToken, err := token.SignedString([]byte(ctx.Configuration.JWTSecret))
if err != nil { if err != nil {
ctx.Error(err, messageOperationFailed) ctx.Error(err, messageOperationFailed)
return return
@ -62,23 +63,23 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs, delayFunc Tim
return return
} }
var (
uri string
)
if uri, err = ctx.ExternalRootURL(); err != nil {
ctx.Error(err, messageOperationFailed)
return
}
disableHTML := false disableHTML := false
if ctx.Configuration.Notifier.SMTP != nil { if ctx.Configuration.Notifier.SMTP != nil {
disableHTML = ctx.Configuration.Notifier.SMTP.DisableHTMLEmails disableHTML = ctx.Configuration.Notifier.SMTP.DisableHTMLEmails
} }
linkURL := ctx.RootURL()
query := linkURL.Query()
query.Set(queryArgToken, signedToken)
linkURL.Path = path.Join(linkURL.Path, args.TargetEndpoint)
linkURL.RawQuery = query.Encode()
values := templates.EmailIdentityVerificationValues{ values := templates.EmailIdentityVerificationValues{
Title: args.MailTitle, Title: args.MailTitle,
LinkURL: fmt.Sprintf("%s%s?token=%s", uri, args.TargetEndpoint, ss), LinkURL: linkURL.String(),
LinkText: args.MailButtonContent, LinkText: args.MailButtonContent,
DisplayName: identity.DisplayName, DisplayName: identity.DisplayName,
RemoteIP: ctx.RemoteIP().String(), RemoteIP: ctx.RemoteIP().String(),

View File

@ -91,24 +91,6 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message) assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
} }
func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
mock.Ctx.Configuration.JWTSecret = testJWTSecret
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
mock.StorageMock.EXPECT().
SaveIdentityVerification(mock.Ctx, gomock.Any()).
Return(nil)
args := newArgs(defaultRetriever)
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message)
}
func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) { func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t) mock := mocks.NewMockAutheliaCtx(t)

View File

@ -11,6 +11,9 @@ const (
fileOpenAPI = "openapi.yml" fileOpenAPI = "openapi.yml"
fileIndexHTML = "index.html" fileIndexHTML = "index.html"
fileLogo = "logo.png" fileLogo = "logo.png"
extHTML = ".html"
extJSON = ".json"
) )
var ( var (
@ -47,6 +50,7 @@ var (
) )
const ( const (
environment = "ENVIRONMENT"
dev = "dev" dev = "dev"
f = "false" f = "false"
t = "true" t = "true"

View File

@ -3,7 +3,6 @@ package server
import ( import (
"net" "net"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
@ -92,21 +91,11 @@ func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
} }
func handleRouter(config schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { func handleRouter(config schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
rememberMe := strconv.FormatBool(config.Session.RememberMeDuration != schema.RememberMeDisabled) optsTemplatedFile := NewTemplatedFileOptions(&config)
resetPassword := strconv.FormatBool(!config.AuthenticationBackend.PasswordReset.Disable)
resetPasswordCustomURL := config.AuthenticationBackend.PasswordReset.CustomURL.String() serveIndexHandler := ServeTemplatedFile(assetsRoot, fileIndexHTML, optsTemplatedFile)
serveSwaggerHandler := ServeTemplatedFile(assetsSwagger, fileIndexHTML, optsTemplatedFile)
duoSelfEnrollment := f serveSwaggerAPIHandler := ServeTemplatedFile(assetsSwagger, fileOpenAPI, optsTemplatedFile)
if !config.DuoAPI.Disable {
duoSelfEnrollment = strconv.FormatBool(config.DuoAPI.EnableSelfEnrollment)
}
https := config.Server.TLS.Key != "" && config.Server.TLS.Certificate != ""
serveIndexHandler := ServeTemplatedFile(assetsRoot, fileIndexHTML, config.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, config.Session.Name, config.Theme, https)
serveSwaggerHandler := ServeTemplatedFile(assetsSwagger, fileIndexHTML, config.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, config.Session.Name, config.Theme, https)
serveSwaggerAPIHandler := ServeTemplatedFile(assetsSwagger, fileOpenAPI, config.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, config.Session.Name, config.Theme, https)
handlerPublicHTML := newPublicHTMLEmbeddedHandler() handlerPublicHTML := newPublicHTMLEmbeddedHandler()
handlerLocales := newLocalesEmbeddedHandler() handlerLocales := newLocalesEmbeddedHandler()
@ -115,7 +104,7 @@ func handleRouter(config schema.Configuration, providers middlewares.Providers)
WithPreMiddlewares(middlewares.SecurityHeaders).Build() WithPreMiddlewares(middlewares.SecurityHeaders).Build()
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder(). policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
WithAllowedMethods("OPTIONS", "GET"). WithAllowedMethods(fasthttp.MethodOptions, fasthttp.MethodGet).
WithAllowedOrigins("*"). WithAllowedOrigins("*").
Build() Build()

View File

@ -6,11 +6,13 @@ import (
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"text/template" "text/template"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
@ -19,7 +21,7 @@ import (
// ServeTemplatedFile serves a templated version of a specified file, // ServeTemplatedFile serves a templated version of a specified file,
// this is utilised to pass information between the backend and frontend // this is utilised to pass information between the backend and frontend
// and generate a nonce to support a restrictive CSP while using material-ui. // and generate a nonce to support a restrictive CSP while using material-ui.
func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, session, theme string, https bool) middlewares.RequestHandler { func ServeTemplatedFile(publicDir, file string, opts *TemplatedFileOptions) middlewares.RequestHandler {
logger := logging.Logger() logger := logging.Logger()
a, err := assets.Open(path.Join(publicDir, file)) a, err := assets.Open(path.Join(publicDir, file))
@ -37,55 +39,40 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
logger.Fatalf("Unable to parse %s template: %s", file, err) logger.Fatalf("Unable to parse %s template: %s", file, err)
} }
return func(ctx *middlewares.AutheliaCtx) { isDevEnvironment := os.Getenv(environment) == dev
base := ""
if baseURL := ctx.UserValueBytes(middlewares.UserValueKeyBaseURL); baseURL != nil {
base = baseURL.(string)
}
return func(ctx *middlewares.AutheliaCtx) {
logoOverride := f logoOverride := f
if assetPath != "" { if opts.AssetPath != "" {
if _, err := os.Stat(filepath.Join(assetPath, fileLogo)); err == nil { if _, err = os.Stat(filepath.Join(opts.AssetPath, fileLogo)); err == nil {
logoOverride = t logoOverride = t
} }
} }
var scheme = schemeHTTPS
if !https {
proto := string(ctx.XForwardedProto())
switch proto {
case "":
break
case schemeHTTP, schemeHTTPS:
scheme = proto
}
}
baseURL := scheme + "://" + string(ctx.XForwardedHost()) + base + "/"
nonce := utils.RandomString(32, utils.CharSetAlphaNumeric, true)
switch extension := filepath.Ext(file); extension { switch extension := filepath.Ext(file); extension {
case ".html": case extHTML:
ctx.SetContentType("text/html; charset=utf-8") ctx.SetContentTypeTextHTML()
case extJSON:
ctx.SetContentTypeApplicationJSON()
default: default:
ctx.SetContentType("text/plain; charset=utf-8") ctx.SetContentTypeTextPlain()
} }
nonce := utils.RandomString(32, utils.CharSetAlphaNumeric, true)
switch { switch {
case publicDir == assetsSwagger: case publicDir == assetsSwagger:
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPSwagger, nonce, nonce)) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPSwagger, nonce, nonce))
case ctx.Configuration.Server.Headers.CSPTemplate != "": case ctx.Configuration.Server.Headers.CSPTemplate != "":
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, strings.ReplaceAll(ctx.Configuration.Server.Headers.CSPTemplate, placeholderCSPNonce, nonce)) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, strings.ReplaceAll(ctx.Configuration.Server.Headers.CSPTemplate, placeholderCSPNonce, nonce))
case os.Getenv("ENVIRONMENT") == dev: case isDevEnvironment:
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPDevelopment, nonce)) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPDevelopment, nonce))
default: default:
ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPDefault, nonce)) ctx.Response.Header.Add(fasthttp.HeaderContentSecurityPolicy, fmt.Sprintf(tmplCSPDefault, nonce))
} }
err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, BaseURL, CSPNonce, DuoSelfEnrollment, LogoOverride, RememberMe, ResetPassword, ResetPasswordCustomURL, Session, Theme string }{Base: base, BaseURL: baseURL, CSPNonce: nonce, DuoSelfEnrollment: duoSelfEnrollment, LogoOverride: logoOverride, RememberMe: rememberMe, ResetPassword: resetPassword, ResetPasswordCustomURL: resetPasswordCustomURL, Session: session, Theme: theme}) if err = tmpl.Execute(ctx.Response.BodyWriter(), opts.CommonData(ctx.BasePath(), ctx.RootURLSlash().String(), nonce, logoOverride)); err != nil {
if err != nil {
ctx.RequestCtx.Error("an error occurred", 503) ctx.RequestCtx.Error("an error occurred", 503)
logger.Errorf("Unable to execute template: %v", err) logger.Errorf("Unable to execute template: %v", err)
@ -128,3 +115,62 @@ func writeHealthCheckEnv(disabled bool, scheme, host, path string, port int) (er
return err return err
} }
// NewTemplatedFileOptions returns a new *TemplatedFileOptions.
func NewTemplatedFileOptions(config *schema.Configuration) (opts *TemplatedFileOptions) {
opts = &TemplatedFileOptions{
AssetPath: config.Server.AssetPath,
DuoSelfEnrollment: f,
RememberMe: strconv.FormatBool(config.Session.RememberMeDuration != schema.RememberMeDisabled),
ResetPassword: strconv.FormatBool(!config.AuthenticationBackend.PasswordReset.Disable),
ResetPasswordCustomURL: config.AuthenticationBackend.PasswordReset.CustomURL.String(),
Theme: config.Theme,
}
if !config.DuoAPI.Disable {
opts.DuoSelfEnrollment = strconv.FormatBool(config.DuoAPI.EnableSelfEnrollment)
}
return opts
}
// TemplatedFileOptions is a struct which is used for many templated files.
type TemplatedFileOptions struct {
AssetPath string
DuoSelfEnrollment string
RememberMe string
ResetPassword string
ResetPasswordCustomURL string
Session string
Theme string
}
// CommonData returns a TemplatedFileCommonData with the dynamic options.
func (options *TemplatedFileOptions) CommonData(base, baseURL, nonce, logoOverride string) TemplatedFileCommonData {
return TemplatedFileCommonData{
Base: base,
BaseURL: baseURL,
CSPNonce: nonce,
LogoOverride: logoOverride,
DuoSelfEnrollment: options.DuoSelfEnrollment,
RememberMe: options.RememberMe,
ResetPassword: options.ResetPassword,
ResetPasswordCustomURL: options.ResetPasswordCustomURL,
Session: options.Session,
Theme: options.Theme,
}
}
// TemplatedFileCommonData is a struct which is used for many templated files.
type TemplatedFileCommonData struct {
Base string
BaseURL string
CSPNonce string
LogoOverride string
DuoSelfEnrollment string
RememberMe string
ResetPassword string
ResetPasswordCustomURL string
Session string
Theme string
}

View File

@ -8,6 +8,7 @@
:8085 { :8085 {
log log
reverse_proxy authelia-backend:9091 { reverse_proxy authelia-backend:9091 {
header_up X-Forwarded-Proto https
import tls-transport import tls-transport
} }
} }

View File

@ -10,7 +10,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
yaml "gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/storage"