From 26236f491e6d2b16ae2bc8297e33a9dc883f44e5 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Mon, 7 Feb 2022 00:37:28 +1100 Subject: [PATCH] fix(server): use of inconsistent methods for determining origin (#2848) This unifies the methods to obtain the X-Forwarded-* header values and provides logical fallbacks. In addition, so we can ensure this functionality extends to the templated files we've converted the ServeTemplatedFile method into a function that operates as a middlewares.RequestHandler. Fixes #2765 --- .../handler_register_u2f_step1_test.go | 19 --- .../handlers/handler_sign_u2f_step1_test.go | 2 +- internal/handlers/handler_verify_test.go | 4 +- internal/middlewares/authelia_context.go | 158 ++++++++++-------- internal/middlewares/authelia_context_test.go | 47 +++++- internal/middlewares/const.go | 6 + .../middlewares/identity_verification_test.go | 18 -- internal/middlewares/strip_path.go | 2 +- internal/server/server.go | 8 +- internal/server/template.go | 19 +-- 10 files changed, 158 insertions(+), 125 deletions(-) diff --git a/internal/handlers/handler_register_u2f_step1_test.go b/internal/handlers/handler_register_u2f_step1_test.go index 6223b3189..3a3a160b6 100644 --- a/internal/handlers/handler_register_u2f_step1_test.go +++ b/internal/handlers/handler_register_u2f_step1_test.go @@ -48,25 +48,6 @@ func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt return ss, verification } -func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() { - token, verification := createToken(s.mock, "john", ActionU2FRegistration, - time.Now().Add(1*time.Minute)) - s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - - s.mock.StorageMock.EXPECT(). - FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). - Return(true, nil) - - s.mock.StorageMock.EXPECT(). - ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))). - Return(nil) - - SecondFactorU2FIdentityFinish(s.mock.Ctx) - - assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), "missing header X-Forwarded-Proto", s.mock.Hook.LastEntry().Message) -} - func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() { s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") token, verification := createToken(s.mock, "john", ActionU2FRegistration, diff --git a/internal/handlers/handler_sign_u2f_step1_test.go b/internal/handlers/handler_sign_u2f_step1_test.go index 5483c8bc4..5381cbb20 100644 --- a/internal/handlers/handler_sign_u2f_step1_test.go +++ b/internal/handlers/handler_sign_u2f_step1_test.go @@ -27,7 +27,7 @@ func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() SecondFactorU2FSignGet(s.mock.Ctx) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), "missing header X-Forwarded-Proto", s.mock.Hook.LastEntry().Message) + assert.Equal(s.T(), "missing header X-Forwarded-Host", s.mock.Hook.LastEntry().Message) } func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() { diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go index 12a37abe8..b8f84f214 100644 --- a/internal/handlers/handler_verify_test.go +++ b/internal/handlers/handler_verify_test.go @@ -45,7 +45,7 @@ func TestShouldRaiseWhenNoHeaderProvidedToDetectTargetURL(t *testing.T) { defer mock.Close() _, err := mock.Ctx.GetOriginalURL() assert.Error(t, err) - assert.Equal(t, "Missing header X-Forwarded-Proto", err.Error()) + assert.Equal(t, "Missing header X-Forwarded-Host", err.Error()) } func TestShouldRaiseWhenNoXForwardedHostHeaderProvidedToDetectTargetURL(t *testing.T) { @@ -67,7 +67,7 @@ func TestShouldRaiseWhenXForwardedProtoIsNotParsable(t *testing.T) { _, err := mock.Ctx.GetOriginalURL() assert.Error(t, err) - assert.Equal(t, "Unable to parse URL !:;;:,://myhost.local: parse \"!:;;:,://myhost.local\": invalid URI for request", err.Error()) + assert.Equal(t, "Unable to parse URL !:;;:,://myhost.local/: parse \"!:;;:,://myhost.local/\": invalid URI for request", err.Error()) } func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) { diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index ec4b0dce7..03016a688 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -55,75 +55,97 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers) } // Error reply with an error and display the stack trace in the logs. -func (c *AutheliaCtx) Error(err error, message string) { - c.SetJSONError(message) +func (ctx *AutheliaCtx) Error(err error, message string) { + ctx.SetJSONError(message) - c.Logger.Error(err) + ctx.Logger.Error(err) } // SetJSONError sets the body of the response to an JSON error KO message. -func (c *AutheliaCtx) SetJSONError(message string) { +func (ctx *AutheliaCtx) SetJSONError(message string) { b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message}) if marshalErr != nil { - c.Logger.Error(marshalErr) + ctx.Logger.Error(marshalErr) } - c.SetContentType(contentTypeApplicationJSON) - c.SetBody(b) + ctx.SetContentType(contentTypeApplicationJSON) + ctx.SetBody(b) } // ReplyError reply with an error but does not display any stack trace in the logs. -func (c *AutheliaCtx) ReplyError(err error, message string) { +func (ctx *AutheliaCtx) ReplyError(err error, message string) { b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message}) if marshalErr != nil { - c.Logger.Error(marshalErr) + ctx.Logger.Error(marshalErr) } - c.SetContentType(contentTypeApplicationJSON) - c.SetBody(b) - c.Logger.Debug(err) + ctx.SetContentType(contentTypeApplicationJSON) + ctx.SetBody(b) + ctx.Logger.Debug(err) } // ReplyUnauthorized response sent when user is unauthorized. -func (c *AutheliaCtx) ReplyUnauthorized() { - c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized) +func (ctx *AutheliaCtx) ReplyUnauthorized() { + ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized) } // ReplyForbidden response sent when access is forbidden to user. -func (c *AutheliaCtx) ReplyForbidden() { - c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden) +func (ctx *AutheliaCtx) ReplyForbidden() { + ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden) } // ReplyBadRequest response sent when bad request has been sent. -func (c *AutheliaCtx) ReplyBadRequest() { - c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest) +func (ctx *AutheliaCtx) ReplyBadRequest() { + ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest) } // XForwardedProto return the content of the X-Forwarded-Proto header. -func (c *AutheliaCtx) XForwardedProto() []byte { - return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto) +func (ctx *AutheliaCtx) XForwardedProto() (proto []byte) { + proto = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto) + + if proto == nil { + if ctx.RequestCtx.IsTLS() { + return protoHTTPS + } + + return protoHTTP + } + + return proto } // XForwardedMethod return the content of the X-Forwarded-Method header. -func (c *AutheliaCtx) XForwardedMethod() []byte { - return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod) +func (ctx *AutheliaCtx) XForwardedMethod() (method []byte) { + return ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod) } // XForwardedHost return the content of the X-Forwarded-Host header. -func (c *AutheliaCtx) XForwardedHost() []byte { - return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost) +func (ctx *AutheliaCtx) XForwardedHost() (host []byte) { + host = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost) + + if host == nil { + return ctx.RequestCtx.Host() + } + + return host } // XForwardedURI return the content of the X-Forwarded-URI header. -func (c *AutheliaCtx) XForwardedURI() []byte { - return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI) +func (ctx *AutheliaCtx) XForwardedURI() (uri []byte) { + uri = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI) + + if len(uri) == 0 { + return ctx.RequestCtx.RequestURI() + } + + return uri } // BasePath returns the base_url as per the path visited by the client. -func (c *AutheliaCtx) BasePath() (base string) { - if baseURL := c.UserValue("base_url"); baseURL != nil { +func (ctx *AutheliaCtx) BasePath() (base string) { + if baseURL := ctx.UserValueBytes(UserValueKeyBaseURL); baseURL != nil { return baseURL.(string) } @@ -131,20 +153,20 @@ func (c *AutheliaCtx) BasePath() (base string) { } // ExternalRootURL gets the X-Forwarded-Proto, X-Forwarded-Host headers and the BasePath and forms them into a URL. -func (c *AutheliaCtx) ExternalRootURL() (string, error) { - protocol := c.XForwardedProto() +func (ctx *AutheliaCtx) ExternalRootURL() (string, error) { + protocol := ctx.XForwardedProto() if protocol == nil { return "", errMissingXForwardedProto } - host := c.XForwardedHost() + host := ctx.XForwardedHost() if host == nil { return "", errMissingXForwardedHost } externalRootURL := fmt.Sprintf("%s://%s", protocol, host) - if base := c.BasePath(); base != "" { + if base := ctx.BasePath(); base != "" { externalBaseURL, err := url.Parse(externalRootURL) if err != nil { return "", err @@ -159,15 +181,15 @@ func (c *AutheliaCtx) ExternalRootURL() (string, error) { } // XOriginalURL return the content of the X-Original-URL header. -func (c *AutheliaCtx) XOriginalURL() []byte { - return c.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL) +func (ctx *AutheliaCtx) XOriginalURL() []byte { + return ctx.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL) } // GetSession return the user session. Any update will be saved in cache. -func (c *AutheliaCtx) GetSession() session.UserSession { - userSession, err := c.Providers.SessionProvider.GetSession(c.RequestCtx) +func (ctx *AutheliaCtx) GetSession() session.UserSession { + userSession, err := ctx.Providers.SessionProvider.GetSession(ctx.RequestCtx) if err != nil { - c.Logger.Error("Unable to retrieve user session") + ctx.Logger.Error("Unable to retrieve user session") return session.NewDefaultUserSession() } @@ -175,19 +197,19 @@ func (c *AutheliaCtx) GetSession() session.UserSession { } // SaveSession save the content of the session. -func (c *AutheliaCtx) SaveSession(userSession session.UserSession) error { - return c.Providers.SessionProvider.SaveSession(c.RequestCtx, userSession) +func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error { + return ctx.Providers.SessionProvider.SaveSession(ctx.RequestCtx, userSession) } // ReplyOK is a helper method to reply ok. -func (c *AutheliaCtx) ReplyOK() { - c.SetContentType(contentTypeApplicationJSON) - c.SetBody(okMessageBytes) +func (ctx *AutheliaCtx) ReplyOK() { + ctx.SetContentType(contentTypeApplicationJSON) + ctx.SetBody(okMessageBytes) } // ParseBody parse the request body into the type of value. -func (c *AutheliaCtx) ParseBody(value interface{}) error { - err := json.Unmarshal(c.PostBody(), &value) +func (ctx *AutheliaCtx) ParseBody(value interface{}) error { + err := json.Unmarshal(ctx.PostBody(), &value) if err != nil { return fmt.Errorf("unable to parse body: %w", err) @@ -207,21 +229,21 @@ func (c *AutheliaCtx) ParseBody(value interface{}) error { } // SetJSONBody Set json body. -func (c *AutheliaCtx) SetJSONBody(value interface{}) error { +func (ctx *AutheliaCtx) SetJSONBody(value interface{}) error { b, err := json.Marshal(OKResponse{Status: "OK", Data: value}) if err != nil { return fmt.Errorf("unable to marshal JSON body: %w", err) } - c.SetContentType(contentTypeApplicationJSON) - c.SetBody(b) + ctx.SetContentType(contentTypeApplicationJSON) + ctx.SetBody(b) return nil } // RemoteIP return the remote IP taking X-Forwarded-For header into account if provided. -func (c *AutheliaCtx) RemoteIP() net.IP { - XForwardedFor := c.Request.Header.PeekBytes(headerXForwardedFor) +func (ctx *AutheliaCtx) RemoteIP() net.IP { + XForwardedFor := ctx.Request.Header.PeekBytes(headerXForwardedFor) if XForwardedFor != nil { ips := strings.Split(string(XForwardedFor), ",") @@ -230,26 +252,24 @@ func (c *AutheliaCtx) RemoteIP() net.IP { } } - return c.RequestCtx.RemoteIP() + return ctx.RequestCtx.RemoteIP() } -// GetOriginalURL extract the URL from the request headers (X-Original-URI or X-Forwarded-* headers). -func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) { - originalURL := c.XOriginalURL() +// GetOriginalURL extract the URL from the request headers (X-Original-URL or X-Forwarded-* headers). +func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) { + originalURL := ctx.XOriginalURL() if originalURL != nil { parsedURL, err := url.ParseRequestURI(string(originalURL)) if err != nil { return nil, fmt.Errorf("Unable to parse URL extracted from X-Original-URL header: %v", err) } - c.Logger.Trace("Using X-Original-URL header content as targeted site URL") + ctx.Logger.Trace("Using X-Original-URL header content as targeted site URL") return parsedURL, nil } - forwardedProto := c.XForwardedProto() - forwardedHost := c.XForwardedHost() - forwardedURI := c.XForwardedURI() + forwardedProto, forwardedHost, forwardedURI := ctx.XForwardedProto(), ctx.XForwardedHost(), ctx.XForwardedURI() if forwardedProto == nil { return nil, errMissingXForwardedProto @@ -271,22 +291,22 @@ func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) { return nil, fmt.Errorf("Unable to parse URL %s: %v", requestURI, err) } - c.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " + + ctx.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " + "to construct targeted site URL") return parsedURL, nil } // IsXHR returns true if the request is a XMLHttpRequest. -func (c AutheliaCtx) IsXHR() (xhr bool) { - requestedWith := c.Request.Header.PeekBytes(headerXRequestedWith) +func (ctx AutheliaCtx) IsXHR() (xhr bool) { + requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith) - return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR + return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR) } // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type. -func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { - accepts := strings.Split(string(c.Request.Header.PeekBytes(headerAccept)), ",") +func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { + accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",") for i, accept := range accepts { mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ") @@ -300,22 +320,22 @@ func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { // SpecialRedirect performs a redirect similar to fasthttp.RequestCtx except it allows statusCode 401 and includes body // content in the form of a link to the location. -func (c *AutheliaCtx) SpecialRedirect(uri string, statusCode int) { +func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) { if statusCode < fasthttp.StatusMovedPermanently || (statusCode > fasthttp.StatusSeeOther && statusCode != fasthttp.StatusTemporaryRedirect && statusCode != fasthttp.StatusPermanentRedirect && statusCode != fasthttp.StatusUnauthorized) { statusCode = fasthttp.StatusFound } - c.SetContentType(contentTypeTextHTML) - c.SetStatusCode(statusCode) + ctx.SetContentType(contentTypeTextHTML) + ctx.SetStatusCode(statusCode) u := fasthttp.AcquireURI() - c.URI().CopyTo(u) + ctx.URI().CopyTo(u) u.Update(uri) - c.Response.Header.SetBytesV("Location", u.FullURI()) + ctx.Response.Header.SetBytesV("Location", u.FullURI()) - c.SetBodyString(fmt.Sprintf("%s", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode))) + ctx.SetBodyString(fmt.Sprintf("%s", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode))) fasthttp.ReleaseURI(u) } diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go index f5ea62813..5a26f1c78 100644 --- a/internal/middlewares/authelia_context_test.go +++ b/internal/middlewares/authelia_context_test.go @@ -57,7 +57,7 @@ func TestShouldGetOriginalURLFromForwardedHeadersWithoutURI(t *testing.T) { originalURL, err := mock.Ctx.GetOriginalURL() assert.NoError(t, err) - expectedURL, err := url.ParseRequestURI("https://home.example.com") + expectedURL, err := url.ParseRequestURI("https://home.example.com/") assert.NoError(t, err) assert.Equal(t, expectedURL, originalURL) } @@ -70,3 +70,48 @@ func TestShouldGetOriginalURLFromForwardedHeadersWithURI(t *testing.T) { assert.Error(t, err) assert.Equal(t, "Unable to parse URL extracted from X-Original-URL header: parse \"htt-ps//home?-.example.com\": invalid URI for request", err.Error()) } + +func TestShouldFallbackToNonXForwardedHeaders(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + mock.Ctx.RequestCtx.Request.SetRequestURI("/2fa/one-time-password") + mock.Ctx.RequestCtx.Request.SetHost("auth.example.com:1234") + + assert.Equal(t, []byte("http"), mock.Ctx.XForwardedProto()) + assert.Equal(t, []byte("auth.example.com:1234"), mock.Ctx.XForwardedHost()) + assert.Equal(t, []byte("/2fa/one-time-password"), mock.Ctx.XForwardedURI()) +} + +func TestShouldOnlyFallbackToNonXForwardedHeadersWhenNil(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + mock.Ctx.RequestCtx.Request.SetRequestURI("/2fa/one-time-password") + mock.Ctx.RequestCtx.Request.SetHost("localhost") + mock.Ctx.RequestCtx.Request.Header.Set(fasthttp.HeaderXForwardedHost, "auth.example.com:1234") + mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-URI", "/base/2fa/one-time-password") + mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-Proto", "https") + mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-Method", "GET") + + assert.Equal(t, []byte("https"), mock.Ctx.XForwardedProto()) + assert.Equal(t, []byte("auth.example.com:1234"), mock.Ctx.XForwardedHost()) + assert.Equal(t, []byte("/base/2fa/one-time-password"), mock.Ctx.XForwardedURI()) + assert.Equal(t, []byte("GET"), mock.Ctx.XForwardedMethod()) +} + +func TestShouldDetectXHR(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + mock.Ctx.RequestCtx.Request.Header.Set(fasthttp.HeaderXRequestedWith, "XMLHttpRequest") + + assert.True(t, mock.Ctx.IsXHR()) +} + +func TestShouldDetectNonXHR(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + assert.False(t, mock.Ctx.IsXHR()) +} diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index 01fc3f1e1..dbe37e6f5 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -14,6 +14,12 @@ var ( headerXForwardedURI = []byte("X-Forwarded-URI") headerXOriginalURL = []byte("X-Original-URL") headerXForwardedMethod = []byte("X-Forwarded-Method") + + protoHTTPS = []byte("https") + protoHTTP = []byte("http") + + // UserValueKeyBaseURL is the User Value key where we store the Base URL. + UserValueKeyBaseURL = []byte("base_url") ) const ( diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go index 38e592c74..8bfedda6e 100644 --- a/internal/middlewares/identity_verification_test.go +++ b/internal/middlewares/identity_verification_test.go @@ -90,24 +90,6 @@ func TestShouldFailSendingAnEmail(t *testing.T) { assert.Equal(t, "no notif", mock.Hook.LastEntry().Message) } -func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) { - mock := mocks.NewMockAutheliaCtx(t) - defer mock.Close() - - mock.Ctx.Configuration.JWTSecret = testJWTSecret - mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") - - mock.StorageMock.EXPECT(). - SaveIdentityVerification(mock.Ctx, gomock.Any()). - Return(nil) - - args := newArgs(defaultRetriever) - middlewares.IdentityVerificationStart(args, nil)(mock.Ctx) - - assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) - assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message) -} - func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() diff --git a/internal/middlewares/strip_path.go b/internal/middlewares/strip_path.go index 2bddcb55f..8079bb420 100644 --- a/internal/middlewares/strip_path.go +++ b/internal/middlewares/strip_path.go @@ -12,7 +12,7 @@ func StripPathMiddleware(path string, next fasthttp.RequestHandler) fasthttp.Req uri := ctx.RequestURI() if strings.HasPrefix(string(uri), path) { - ctx.SetUserValue("base_url", path) + ctx.SetUserValueBytes(UserValueKeyBaseURL, path) newURI := strings.TrimPrefix(string(uri), path) ctx.Request.SetRequestURI(newURI) diff --git a/internal/server/server.go b/internal/server/server.go index d75bb0d4f..cbc53e71b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -46,11 +46,11 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, configuration.Session.Name, configuration.Theme, https) r := router.New() - r.GET("/", serveIndexHandler) + r.GET("/", autheliaMiddleware(serveIndexHandler)) r.OPTIONS("/", autheliaMiddleware(handleOPTIONS)) - r.GET("/api/", serveSwaggerHandler) - r.GET("/api/"+apiFile, serveSwaggerAPIHandler) + r.GET("/api/", autheliaMiddleware(serveSwaggerHandler)) + r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler)) for _, f := range rootFiles { r.GET("/"+f, middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, embeddedFS)) @@ -148,7 +148,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr r.GET("/debug/vars", expvarhandler.ExpvarHandler) } - r.NotFound = serveIndexHandler + r.NotFound = autheliaMiddleware(serveIndexHandler) handler := middlewares.LogRequestMiddleware(r.Handler) if configuration.Server.Path != "" { diff --git a/internal/server/template.go b/internal/server/template.go index a95927563..e3214aca1 100644 --- a/internal/server/template.go +++ b/internal/server/template.go @@ -7,16 +7,15 @@ import ( "path/filepath" "text/template" - "github.com/valyala/fasthttp" - "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/utils" ) // ServeTemplatedFile serves a templated version of a specified file, // this is utilised to pass information between the backend and frontend // and generate a nonce to support a restrictive CSP while using material-ui. -func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, session, theme string, https bool) fasthttp.RequestHandler { +func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, session, theme string, https bool) middlewares.RequestHandler { logger := logging.Logger() a, err := assets.Open(publicDir + file) @@ -34,9 +33,9 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM logger.Fatalf("Unable to parse %s template: %s", file, err) } - return func(ctx *fasthttp.RequestCtx) { + return func(ctx *middlewares.AutheliaCtx) { base := "" - if baseURL := ctx.UserValue("base_url"); baseURL != nil { + if baseURL := ctx.UserValueBytes(middlewares.UserValueKeyBaseURL); baseURL != nil { base = baseURL.(string) } @@ -51,16 +50,16 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM var scheme = "https" if !https { - proto := string(ctx.Request.Header.Peek(fasthttp.HeaderXForwardedProto)) + proto := string(ctx.XForwardedProto()) switch proto { case "": - scheme = "http" - default: + break + case "http", "https": scheme = proto } } - baseURL := scheme + "://" + string(ctx.Request.Host()) + base + "/" + baseURL := scheme + "://" + string(ctx.XForwardedHost()) + base + "/" nonce := utils.RandomString(32, utils.AlphaNumericCharacters, true) switch extension := filepath.Ext(file); extension { @@ -81,7 +80,7 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, BaseURL, CSPNonce, DuoSelfEnrollment, LogoOverride, RememberMe, ResetPassword, Session, Theme string }{Base: base, BaseURL: baseURL, CSPNonce: nonce, DuoSelfEnrollment: duoSelfEnrollment, LogoOverride: logoOverride, RememberMe: rememberMe, ResetPassword: resetPassword, Session: session, Theme: theme}) if err != nil { - ctx.Error("an error occurred", 503) + ctx.RequestCtx.Error("an error occurred", 503) logger.Errorf("Unable to execute template: %v", err) return