From f12346e39c606f3aa3fd5a754f4d00cc269a0ec6 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Mon, 25 Jul 2022 20:43:50 +1000 Subject: [PATCH] fix(server): missing cache-control header (#3737) This fixes a missing cache control header. Fixes #3732. --- internal/server/asset.go | 95 ++++++++++++++++++++++++++++++++++------ internal/server/const.go | 12 +++++ 2 files changed, 93 insertions(+), 14 deletions(-) diff --git a/internal/server/asset.go b/internal/server/asset.go index 4df0d4211..d834ccbed 100644 --- a/internal/server/asset.go +++ b/internal/server/asset.go @@ -1,15 +1,20 @@ package server import ( + "bytes" + "crypto/sha1" //nolint:gosec // Usage is for collision avoidance not security. "embed" "errors" "fmt" "io/fs" + "mime" "net/http" + "path" + "path/filepath" "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttpadaptor" + "github.com/authelia/authelia/v4/internal/handlers" "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/utils" ) @@ -21,9 +26,43 @@ var locales embed.FS var assets embed.FS func newPublicHTMLEmbeddedHandler() fasthttp.RequestHandler { - embeddedPath, _ := fs.Sub(assets, "public_html") + etags := map[string][]byte{} - return fasthttpadaptor.NewFastHTTPHandler(http.FileServer(http.FS(embeddedPath))) + getEmbedETags(assets, "public_html", etags) + + return func(ctx *fasthttp.RequestCtx) { + p := path.Join("public_html", string(ctx.Path())) + + if etag, ok := etags[p]; ok { + ctx.Response.Header.SetBytesKV(headerETag, etag) + ctx.Response.Header.SetBytesKV(headerCacheControl, headerValueCacheControlETaggedAssets) + + if bytes.Equal(etag, ctx.Request.Header.PeekBytes(headerIfNoneMatch)) { + ctx.SetStatusCode(fasthttp.StatusNotModified) + + return + } + } + + var ( + data []byte + err error + ) + + if data, err = assets.ReadFile(p); err != nil { + hfsHandleErr(ctx, err) + + return + } + + contentType := mime.TypeByExtension(path.Ext(p)) + if len(contentType) == 0 { + contentType = http.DetectContentType(data) + } + + ctx.SetContentType(contentType) + ctx.SetBody(data) + } } func newLocalesEmbeddedHandler() (handler fasthttp.RequestHandler) { @@ -72,18 +111,46 @@ func newLocalesEmbeddedHandler() (handler fasthttp.RequestHandler) { } } -func hfsHandleErr(ctx *fasthttp.RequestCtx, err error) { - switch { - case errors.Is(err, fs.ErrNotExist): - writeStatus(ctx, fasthttp.StatusNotFound) - case errors.Is(err, fs.ErrPermission): - writeStatus(ctx, fasthttp.StatusForbidden) - default: - writeStatus(ctx, fasthttp.StatusInternalServerError) +func getEmbedETags(embedFS embed.FS, root string, etags map[string][]byte) { + var ( + err error + entries []fs.DirEntry + ) + + if entries, err = embedFS.ReadDir(root); err != nil { + return + } + + for _, entry := range entries { + if entry.IsDir() { + getEmbedETags(embedFS, filepath.Join(root, entry.Name()), etags) + + continue + } + + p := filepath.Join(root, entry.Name()) + + var data []byte + + if data, err = embedFS.ReadFile(p); err != nil { + continue + } + + sum := sha1.New() //nolint:gosec // Usage is for collision avoidance not security. + + sum.Write(data) + + etags[p] = []byte(fmt.Sprintf("%x", sum.Sum(nil))) } } -func writeStatus(ctx *fasthttp.RequestCtx, status int) { - ctx.SetStatusCode(status) - ctx.SetBodyString(fmt.Sprintf("%d %s", status, fasthttp.StatusMessage(status))) +func hfsHandleErr(ctx *fasthttp.RequestCtx, err error) { + switch { + case errors.Is(err, fs.ErrNotExist): + handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound) + case errors.Is(err, fs.ErrPermission): + handlers.SetStatusCodeResponse(ctx, fasthttp.StatusForbidden) + default: + handlers.SetStatusCodeResponse(ctx, fasthttp.StatusInternalServerError) + } } diff --git a/internal/server/const.go b/internal/server/const.go index ae815cd45..483a3ab0f 100644 --- a/internal/server/const.go +++ b/internal/server/const.go @@ -1,5 +1,9 @@ package server +import ( + "github.com/valyala/fasthttp" +) + const ( embeddedAssets = "public_html/" swaggerAssets = embeddedAssets + "api/" @@ -50,6 +54,14 @@ const ( schemeHTTPS = "https" ) +var ( + headerETag = []byte(fasthttp.HeaderETag) + headerIfNoneMatch = []byte(fasthttp.HeaderIfNoneMatch) + headerCacheControl = []byte(fasthttp.HeaderCacheControl) + + headerValueCacheControlETaggedAssets = []byte("public, max-age=0, must-revalidate") +) + const healthCheckEnv = `# Written by Authelia Process X_AUTHELIA_HEALTHCHECK=1 X_AUTHELIA_HEALTHCHECK_SCHEME=%s