fix(server): use of inconsistent methods for determining origin (#2848)
This unifies the methods to obtain the X-Forwarded-* header values and provides logical fallbacks. In addition, so we can ensure this functionality extends to the templated files we've converted the ServeTemplatedFile method into a function that operates as a middlewares.RequestHandler. Fixes #2765pull/2799/head^2
parent
7775d2af0e
commit
26236f491e
|
@ -48,25 +48,6 @@ func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt
|
||||||
return ss, verification
|
return ss, verification
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() {
|
|
||||||
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
|
||||||
time.Now().Add(1*time.Minute))
|
|
||||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
|
||||||
|
|
||||||
s.mock.StorageMock.EXPECT().
|
|
||||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
|
||||||
Return(true, nil)
|
|
||||||
|
|
||||||
s.mock.StorageMock.EXPECT().
|
|
||||||
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
|
||||||
Return(nil)
|
|
||||||
|
|
||||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
|
||||||
|
|
||||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
|
||||||
assert.Equal(s.T(), "missing header X-Forwarded-Proto", s.mock.Hook.LastEntry().Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
||||||
s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
||||||
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing()
|
||||||
SecondFactorU2FSignGet(s.mock.Ctx)
|
SecondFactorU2FSignGet(s.mock.Ctx)
|
||||||
|
|
||||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||||
assert.Equal(s.T(), "missing header X-Forwarded-Proto", s.mock.Hook.LastEntry().Message)
|
assert.Equal(s.T(), "missing header X-Forwarded-Host", s.mock.Hook.LastEntry().Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
||||||
|
|
|
@ -45,7 +45,7 @@ func TestShouldRaiseWhenNoHeaderProvidedToDetectTargetURL(t *testing.T) {
|
||||||
defer mock.Close()
|
defer mock.Close()
|
||||||
_, err := mock.Ctx.GetOriginalURL()
|
_, err := mock.Ctx.GetOriginalURL()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, "Missing header X-Forwarded-Proto", err.Error())
|
assert.Equal(t, "Missing header X-Forwarded-Host", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldRaiseWhenNoXForwardedHostHeaderProvidedToDetectTargetURL(t *testing.T) {
|
func TestShouldRaiseWhenNoXForwardedHostHeaderProvidedToDetectTargetURL(t *testing.T) {
|
||||||
|
@ -67,7 +67,7 @@ func TestShouldRaiseWhenXForwardedProtoIsNotParsable(t *testing.T) {
|
||||||
|
|
||||||
_, err := mock.Ctx.GetOriginalURL()
|
_, err := mock.Ctx.GetOriginalURL()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, "Unable to parse URL !:;;:,://myhost.local: parse \"!:;;:,://myhost.local\": invalid URI for request", err.Error())
|
assert.Equal(t, "Unable to parse URL !:;;:,://myhost.local/: parse \"!:;;:,://myhost.local/\": invalid URI for request", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) {
|
func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) {
|
||||||
|
|
|
@ -55,75 +55,97 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error reply with an error and display the stack trace in the logs.
|
// Error reply with an error and display the stack trace in the logs.
|
||||||
func (c *AutheliaCtx) Error(err error, message string) {
|
func (ctx *AutheliaCtx) Error(err error, message string) {
|
||||||
c.SetJSONError(message)
|
ctx.SetJSONError(message)
|
||||||
|
|
||||||
c.Logger.Error(err)
|
ctx.Logger.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetJSONError sets the body of the response to an JSON error KO message.
|
// SetJSONError sets the body of the response to an JSON error KO message.
|
||||||
func (c *AutheliaCtx) SetJSONError(message string) {
|
func (ctx *AutheliaCtx) SetJSONError(message string) {
|
||||||
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
|
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
|
||||||
|
|
||||||
if marshalErr != nil {
|
if marshalErr != nil {
|
||||||
c.Logger.Error(marshalErr)
|
ctx.Logger.Error(marshalErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetContentType(contentTypeApplicationJSON)
|
ctx.SetContentType(contentTypeApplicationJSON)
|
||||||
c.SetBody(b)
|
ctx.SetBody(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyError reply with an error but does not display any stack trace in the logs.
|
// ReplyError reply with an error but does not display any stack trace in the logs.
|
||||||
func (c *AutheliaCtx) ReplyError(err error, message string) {
|
func (ctx *AutheliaCtx) ReplyError(err error, message string) {
|
||||||
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
|
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
|
||||||
|
|
||||||
if marshalErr != nil {
|
if marshalErr != nil {
|
||||||
c.Logger.Error(marshalErr)
|
ctx.Logger.Error(marshalErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetContentType(contentTypeApplicationJSON)
|
ctx.SetContentType(contentTypeApplicationJSON)
|
||||||
c.SetBody(b)
|
ctx.SetBody(b)
|
||||||
c.Logger.Debug(err)
|
ctx.Logger.Debug(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyUnauthorized response sent when user is unauthorized.
|
// ReplyUnauthorized response sent when user is unauthorized.
|
||||||
func (c *AutheliaCtx) ReplyUnauthorized() {
|
func (ctx *AutheliaCtx) ReplyUnauthorized() {
|
||||||
c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
|
ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyForbidden response sent when access is forbidden to user.
|
// ReplyForbidden response sent when access is forbidden to user.
|
||||||
func (c *AutheliaCtx) ReplyForbidden() {
|
func (ctx *AutheliaCtx) ReplyForbidden() {
|
||||||
c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden)
|
ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyBadRequest response sent when bad request has been sent.
|
// ReplyBadRequest response sent when bad request has been sent.
|
||||||
func (c *AutheliaCtx) ReplyBadRequest() {
|
func (ctx *AutheliaCtx) ReplyBadRequest() {
|
||||||
c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest)
|
ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// XForwardedProto return the content of the X-Forwarded-Proto header.
|
// XForwardedProto return the content of the X-Forwarded-Proto header.
|
||||||
func (c *AutheliaCtx) XForwardedProto() []byte {
|
func (ctx *AutheliaCtx) XForwardedProto() (proto []byte) {
|
||||||
return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto)
|
proto = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto)
|
||||||
|
|
||||||
|
if proto == nil {
|
||||||
|
if ctx.RequestCtx.IsTLS() {
|
||||||
|
return protoHTTPS
|
||||||
|
}
|
||||||
|
|
||||||
|
return protoHTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
return proto
|
||||||
}
|
}
|
||||||
|
|
||||||
// XForwardedMethod return the content of the X-Forwarded-Method header.
|
// XForwardedMethod return the content of the X-Forwarded-Method header.
|
||||||
func (c *AutheliaCtx) XForwardedMethod() []byte {
|
func (ctx *AutheliaCtx) XForwardedMethod() (method []byte) {
|
||||||
return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
|
return ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
|
||||||
}
|
}
|
||||||
|
|
||||||
// XForwardedHost return the content of the X-Forwarded-Host header.
|
// XForwardedHost return the content of the X-Forwarded-Host header.
|
||||||
func (c *AutheliaCtx) XForwardedHost() []byte {
|
func (ctx *AutheliaCtx) XForwardedHost() (host []byte) {
|
||||||
return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost)
|
host = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost)
|
||||||
|
|
||||||
|
if host == nil {
|
||||||
|
return ctx.RequestCtx.Host()
|
||||||
|
}
|
||||||
|
|
||||||
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
// XForwardedURI return the content of the X-Forwarded-URI header.
|
// XForwardedURI return the content of the X-Forwarded-URI header.
|
||||||
func (c *AutheliaCtx) XForwardedURI() []byte {
|
func (ctx *AutheliaCtx) XForwardedURI() (uri []byte) {
|
||||||
return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI)
|
uri = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI)
|
||||||
|
|
||||||
|
if len(uri) == 0 {
|
||||||
|
return ctx.RequestCtx.RequestURI()
|
||||||
|
}
|
||||||
|
|
||||||
|
return uri
|
||||||
}
|
}
|
||||||
|
|
||||||
// BasePath returns the base_url as per the path visited by the client.
|
// BasePath returns the base_url as per the path visited by the client.
|
||||||
func (c *AutheliaCtx) BasePath() (base string) {
|
func (ctx *AutheliaCtx) BasePath() (base string) {
|
||||||
if baseURL := c.UserValue("base_url"); baseURL != nil {
|
if baseURL := ctx.UserValueBytes(UserValueKeyBaseURL); baseURL != nil {
|
||||||
return baseURL.(string)
|
return baseURL.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,20 +153,20 @@ func (c *AutheliaCtx) BasePath() (base string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExternalRootURL gets the X-Forwarded-Proto, X-Forwarded-Host headers and the BasePath and forms them into a URL.
|
// ExternalRootURL gets the X-Forwarded-Proto, X-Forwarded-Host headers and the BasePath and forms them into a URL.
|
||||||
func (c *AutheliaCtx) ExternalRootURL() (string, error) {
|
func (ctx *AutheliaCtx) ExternalRootURL() (string, error) {
|
||||||
protocol := c.XForwardedProto()
|
protocol := ctx.XForwardedProto()
|
||||||
if protocol == nil {
|
if protocol == nil {
|
||||||
return "", errMissingXForwardedProto
|
return "", errMissingXForwardedProto
|
||||||
}
|
}
|
||||||
|
|
||||||
host := c.XForwardedHost()
|
host := ctx.XForwardedHost()
|
||||||
if host == nil {
|
if host == nil {
|
||||||
return "", errMissingXForwardedHost
|
return "", errMissingXForwardedHost
|
||||||
}
|
}
|
||||||
|
|
||||||
externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
|
externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
|
||||||
|
|
||||||
if base := c.BasePath(); base != "" {
|
if base := ctx.BasePath(); base != "" {
|
||||||
externalBaseURL, err := url.Parse(externalRootURL)
|
externalBaseURL, err := url.Parse(externalRootURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -159,15 +181,15 @@ func (c *AutheliaCtx) ExternalRootURL() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// XOriginalURL return the content of the X-Original-URL header.
|
// XOriginalURL return the content of the X-Original-URL header.
|
||||||
func (c *AutheliaCtx) XOriginalURL() []byte {
|
func (ctx *AutheliaCtx) XOriginalURL() []byte {
|
||||||
return c.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
|
return ctx.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSession return the user session. Any update will be saved in cache.
|
// GetSession return the user session. Any update will be saved in cache.
|
||||||
func (c *AutheliaCtx) GetSession() session.UserSession {
|
func (ctx *AutheliaCtx) GetSession() session.UserSession {
|
||||||
userSession, err := c.Providers.SessionProvider.GetSession(c.RequestCtx)
|
userSession, err := ctx.Providers.SessionProvider.GetSession(ctx.RequestCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Unable to retrieve user session")
|
ctx.Logger.Error("Unable to retrieve user session")
|
||||||
return session.NewDefaultUserSession()
|
return session.NewDefaultUserSession()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,19 +197,19 @@ func (c *AutheliaCtx) GetSession() session.UserSession {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession save the content of the session.
|
// SaveSession save the content of the session.
|
||||||
func (c *AutheliaCtx) SaveSession(userSession session.UserSession) error {
|
func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error {
|
||||||
return c.Providers.SessionProvider.SaveSession(c.RequestCtx, userSession)
|
return ctx.Providers.SessionProvider.SaveSession(ctx.RequestCtx, userSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyOK is a helper method to reply ok.
|
// ReplyOK is a helper method to reply ok.
|
||||||
func (c *AutheliaCtx) ReplyOK() {
|
func (ctx *AutheliaCtx) ReplyOK() {
|
||||||
c.SetContentType(contentTypeApplicationJSON)
|
ctx.SetContentType(contentTypeApplicationJSON)
|
||||||
c.SetBody(okMessageBytes)
|
ctx.SetBody(okMessageBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseBody parse the request body into the type of value.
|
// ParseBody parse the request body into the type of value.
|
||||||
func (c *AutheliaCtx) ParseBody(value interface{}) error {
|
func (ctx *AutheliaCtx) ParseBody(value interface{}) error {
|
||||||
err := json.Unmarshal(c.PostBody(), &value)
|
err := json.Unmarshal(ctx.PostBody(), &value)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to parse body: %w", err)
|
return fmt.Errorf("unable to parse body: %w", err)
|
||||||
|
@ -207,21 +229,21 @@ func (c *AutheliaCtx) ParseBody(value interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetJSONBody Set json body.
|
// SetJSONBody Set json body.
|
||||||
func (c *AutheliaCtx) SetJSONBody(value interface{}) error {
|
func (ctx *AutheliaCtx) SetJSONBody(value interface{}) error {
|
||||||
b, err := json.Marshal(OKResponse{Status: "OK", Data: value})
|
b, err := json.Marshal(OKResponse{Status: "OK", Data: value})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to marshal JSON body: %w", err)
|
return fmt.Errorf("unable to marshal JSON body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetContentType(contentTypeApplicationJSON)
|
ctx.SetContentType(contentTypeApplicationJSON)
|
||||||
c.SetBody(b)
|
ctx.SetBody(b)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
|
// RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
|
||||||
func (c *AutheliaCtx) RemoteIP() net.IP {
|
func (ctx *AutheliaCtx) RemoteIP() net.IP {
|
||||||
XForwardedFor := c.Request.Header.PeekBytes(headerXForwardedFor)
|
XForwardedFor := ctx.Request.Header.PeekBytes(headerXForwardedFor)
|
||||||
if XForwardedFor != nil {
|
if XForwardedFor != nil {
|
||||||
ips := strings.Split(string(XForwardedFor), ",")
|
ips := strings.Split(string(XForwardedFor), ",")
|
||||||
|
|
||||||
|
@ -230,26 +252,24 @@ func (c *AutheliaCtx) RemoteIP() net.IP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.RequestCtx.RemoteIP()
|
return ctx.RequestCtx.RemoteIP()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOriginalURL extract the URL from the request headers (X-Original-URI or X-Forwarded-* headers).
|
// GetOriginalURL extract the URL from the request headers (X-Original-URL or X-Forwarded-* headers).
|
||||||
func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
|
func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
|
||||||
originalURL := c.XOriginalURL()
|
originalURL := ctx.XOriginalURL()
|
||||||
if originalURL != nil {
|
if originalURL != nil {
|
||||||
parsedURL, err := url.ParseRequestURI(string(originalURL))
|
parsedURL, err := url.ParseRequestURI(string(originalURL))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to parse URL extracted from X-Original-URL header: %v", err)
|
return nil, fmt.Errorf("Unable to parse URL extracted from X-Original-URL header: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Logger.Trace("Using X-Original-URL header content as targeted site URL")
|
ctx.Logger.Trace("Using X-Original-URL header content as targeted site URL")
|
||||||
|
|
||||||
return parsedURL, nil
|
return parsedURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardedProto := c.XForwardedProto()
|
forwardedProto, forwardedHost, forwardedURI := ctx.XForwardedProto(), ctx.XForwardedHost(), ctx.XForwardedURI()
|
||||||
forwardedHost := c.XForwardedHost()
|
|
||||||
forwardedURI := c.XForwardedURI()
|
|
||||||
|
|
||||||
if forwardedProto == nil {
|
if forwardedProto == nil {
|
||||||
return nil, errMissingXForwardedProto
|
return nil, errMissingXForwardedProto
|
||||||
|
@ -271,22 +291,22 @@ func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
|
||||||
return nil, fmt.Errorf("Unable to parse URL %s: %v", requestURI, err)
|
return nil, fmt.Errorf("Unable to parse URL %s: %v", requestURI, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " +
|
ctx.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " +
|
||||||
"to construct targeted site URL")
|
"to construct targeted site URL")
|
||||||
|
|
||||||
return parsedURL, nil
|
return parsedURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsXHR returns true if the request is a XMLHttpRequest.
|
// IsXHR returns true if the request is a XMLHttpRequest.
|
||||||
func (c AutheliaCtx) IsXHR() (xhr bool) {
|
func (ctx AutheliaCtx) IsXHR() (xhr bool) {
|
||||||
requestedWith := c.Request.Header.PeekBytes(headerXRequestedWith)
|
requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith)
|
||||||
|
|
||||||
return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR
|
return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
|
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
|
||||||
func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
||||||
accepts := strings.Split(string(c.Request.Header.PeekBytes(headerAccept)), ",")
|
accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",")
|
||||||
|
|
||||||
for i, accept := range accepts {
|
for i, accept := range accepts {
|
||||||
mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ")
|
mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ")
|
||||||
|
@ -300,22 +320,22 @@ func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
|
||||||
|
|
||||||
// SpecialRedirect performs a redirect similar to fasthttp.RequestCtx except it allows statusCode 401 and includes body
|
// SpecialRedirect performs a redirect similar to fasthttp.RequestCtx except it allows statusCode 401 and includes body
|
||||||
// content in the form of a link to the location.
|
// content in the form of a link to the location.
|
||||||
func (c *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
|
func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
|
||||||
if statusCode < fasthttp.StatusMovedPermanently || (statusCode > fasthttp.StatusSeeOther && statusCode != fasthttp.StatusTemporaryRedirect && statusCode != fasthttp.StatusPermanentRedirect && statusCode != fasthttp.StatusUnauthorized) {
|
if statusCode < fasthttp.StatusMovedPermanently || (statusCode > fasthttp.StatusSeeOther && statusCode != fasthttp.StatusTemporaryRedirect && statusCode != fasthttp.StatusPermanentRedirect && statusCode != fasthttp.StatusUnauthorized) {
|
||||||
statusCode = fasthttp.StatusFound
|
statusCode = fasthttp.StatusFound
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetContentType(contentTypeTextHTML)
|
ctx.SetContentType(contentTypeTextHTML)
|
||||||
c.SetStatusCode(statusCode)
|
ctx.SetStatusCode(statusCode)
|
||||||
|
|
||||||
u := fasthttp.AcquireURI()
|
u := fasthttp.AcquireURI()
|
||||||
|
|
||||||
c.URI().CopyTo(u)
|
ctx.URI().CopyTo(u)
|
||||||
u.Update(uri)
|
u.Update(uri)
|
||||||
|
|
||||||
c.Response.Header.SetBytesV("Location", u.FullURI())
|
ctx.Response.Header.SetBytesV("Location", u.FullURI())
|
||||||
|
|
||||||
c.SetBodyString(fmt.Sprintf("<a href=\"%s\">%s</a>", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode)))
|
ctx.SetBodyString(fmt.Sprintf("<a href=\"%s\">%s</a>", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode)))
|
||||||
|
|
||||||
fasthttp.ReleaseURI(u)
|
fasthttp.ReleaseURI(u)
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ func TestShouldGetOriginalURLFromForwardedHeadersWithoutURI(t *testing.T) {
|
||||||
originalURL, err := mock.Ctx.GetOriginalURL()
|
originalURL, err := mock.Ctx.GetOriginalURL()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expectedURL, err := url.ParseRequestURI("https://home.example.com")
|
expectedURL, err := url.ParseRequestURI("https://home.example.com/")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedURL, originalURL)
|
assert.Equal(t, expectedURL, originalURL)
|
||||||
}
|
}
|
||||||
|
@ -70,3 +70,48 @@ func TestShouldGetOriginalURLFromForwardedHeadersWithURI(t *testing.T) {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, "Unable to parse URL extracted from X-Original-URL header: parse \"htt-ps//home?-.example.com\": invalid URI for request", err.Error())
|
assert.Equal(t, "Unable to parse URL extracted from X-Original-URL header: parse \"htt-ps//home?-.example.com\": invalid URI for request", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShouldFallbackToNonXForwardedHeaders(t *testing.T) {
|
||||||
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
mock.Ctx.RequestCtx.Request.SetRequestURI("/2fa/one-time-password")
|
||||||
|
mock.Ctx.RequestCtx.Request.SetHost("auth.example.com:1234")
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("http"), mock.Ctx.XForwardedProto())
|
||||||
|
assert.Equal(t, []byte("auth.example.com:1234"), mock.Ctx.XForwardedHost())
|
||||||
|
assert.Equal(t, []byte("/2fa/one-time-password"), mock.Ctx.XForwardedURI())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldOnlyFallbackToNonXForwardedHeadersWhenNil(t *testing.T) {
|
||||||
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
mock.Ctx.RequestCtx.Request.SetRequestURI("/2fa/one-time-password")
|
||||||
|
mock.Ctx.RequestCtx.Request.SetHost("localhost")
|
||||||
|
mock.Ctx.RequestCtx.Request.Header.Set(fasthttp.HeaderXForwardedHost, "auth.example.com:1234")
|
||||||
|
mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-URI", "/base/2fa/one-time-password")
|
||||||
|
mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-Method", "GET")
|
||||||
|
|
||||||
|
assert.Equal(t, []byte("https"), mock.Ctx.XForwardedProto())
|
||||||
|
assert.Equal(t, []byte("auth.example.com:1234"), mock.Ctx.XForwardedHost())
|
||||||
|
assert.Equal(t, []byte("/base/2fa/one-time-password"), mock.Ctx.XForwardedURI())
|
||||||
|
assert.Equal(t, []byte("GET"), mock.Ctx.XForwardedMethod())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldDetectXHR(t *testing.T) {
|
||||||
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
mock.Ctx.RequestCtx.Request.Header.Set(fasthttp.HeaderXRequestedWith, "XMLHttpRequest")
|
||||||
|
|
||||||
|
assert.True(t, mock.Ctx.IsXHR())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldDetectNonXHR(t *testing.T) {
|
||||||
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
|
defer mock.Close()
|
||||||
|
|
||||||
|
assert.False(t, mock.Ctx.IsXHR())
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,12 @@ var (
|
||||||
headerXForwardedURI = []byte("X-Forwarded-URI")
|
headerXForwardedURI = []byte("X-Forwarded-URI")
|
||||||
headerXOriginalURL = []byte("X-Original-URL")
|
headerXOriginalURL = []byte("X-Original-URL")
|
||||||
headerXForwardedMethod = []byte("X-Forwarded-Method")
|
headerXForwardedMethod = []byte("X-Forwarded-Method")
|
||||||
|
|
||||||
|
protoHTTPS = []byte("https")
|
||||||
|
protoHTTP = []byte("http")
|
||||||
|
|
||||||
|
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
||||||
|
UserValueKeyBaseURL = []byte("base_url")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -90,24 +90,6 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
|
||||||
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
|
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
|
|
||||||
mock := mocks.NewMockAutheliaCtx(t)
|
|
||||||
defer mock.Close()
|
|
||||||
|
|
||||||
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
|
||||||
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
|
||||||
|
|
||||||
mock.StorageMock.EXPECT().
|
|
||||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
|
||||||
Return(nil)
|
|
||||||
|
|
||||||
args := newArgs(defaultRetriever)
|
|
||||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
|
||||||
assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
|
func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
|
||||||
mock := mocks.NewMockAutheliaCtx(t)
|
mock := mocks.NewMockAutheliaCtx(t)
|
||||||
defer mock.Close()
|
defer mock.Close()
|
||||||
|
|
|
@ -12,7 +12,7 @@ func StripPathMiddleware(path string, next fasthttp.RequestHandler) fasthttp.Req
|
||||||
uri := ctx.RequestURI()
|
uri := ctx.RequestURI()
|
||||||
|
|
||||||
if strings.HasPrefix(string(uri), path) {
|
if strings.HasPrefix(string(uri), path) {
|
||||||
ctx.SetUserValue("base_url", path)
|
ctx.SetUserValueBytes(UserValueKeyBaseURL, path)
|
||||||
|
|
||||||
newURI := strings.TrimPrefix(string(uri), path)
|
newURI := strings.TrimPrefix(string(uri), path)
|
||||||
ctx.Request.SetRequestURI(newURI)
|
ctx.Request.SetRequestURI(newURI)
|
||||||
|
|
|
@ -46,11 +46,11 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
||||||
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, configuration.Session.Name, configuration.Theme, https)
|
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, configuration.Session.Name, configuration.Theme, https)
|
||||||
|
|
||||||
r := router.New()
|
r := router.New()
|
||||||
r.GET("/", serveIndexHandler)
|
r.GET("/", autheliaMiddleware(serveIndexHandler))
|
||||||
r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
|
r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
|
||||||
|
|
||||||
r.GET("/api/", serveSwaggerHandler)
|
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
|
||||||
r.GET("/api/"+apiFile, serveSwaggerAPIHandler)
|
r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler))
|
||||||
|
|
||||||
for _, f := range rootFiles {
|
for _, f := range rootFiles {
|
||||||
r.GET("/"+f, middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, embeddedFS))
|
r.GET("/"+f, middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, embeddedFS))
|
||||||
|
@ -148,7 +148,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
||||||
r.GET("/debug/vars", expvarhandler.ExpvarHandler)
|
r.GET("/debug/vars", expvarhandler.ExpvarHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.NotFound = serveIndexHandler
|
r.NotFound = autheliaMiddleware(serveIndexHandler)
|
||||||
|
|
||||||
handler := middlewares.LogRequestMiddleware(r.Handler)
|
handler := middlewares.LogRequestMiddleware(r.Handler)
|
||||||
if configuration.Server.Path != "" {
|
if configuration.Server.Path != "" {
|
||||||
|
|
|
@ -7,16 +7,15 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/logging"
|
"github.com/authelia/authelia/v4/internal/logging"
|
||||||
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ServeTemplatedFile serves a templated version of a specified file,
|
// ServeTemplatedFile serves a templated version of a specified file,
|
||||||
// this is utilised to pass information between the backend and frontend
|
// this is utilised to pass information between the backend and frontend
|
||||||
// and generate a nonce to support a restrictive CSP while using material-ui.
|
// and generate a nonce to support a restrictive CSP while using material-ui.
|
||||||
func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, session, theme string, https bool) fasthttp.RequestHandler {
|
func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, session, theme string, https bool) middlewares.RequestHandler {
|
||||||
logger := logging.Logger()
|
logger := logging.Logger()
|
||||||
|
|
||||||
a, err := assets.Open(publicDir + file)
|
a, err := assets.Open(publicDir + file)
|
||||||
|
@ -34,9 +33,9 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
|
||||||
logger.Fatalf("Unable to parse %s template: %s", file, err)
|
logger.Fatalf("Unable to parse %s template: %s", file, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
return func(ctx *middlewares.AutheliaCtx) {
|
||||||
base := ""
|
base := ""
|
||||||
if baseURL := ctx.UserValue("base_url"); baseURL != nil {
|
if baseURL := ctx.UserValueBytes(middlewares.UserValueKeyBaseURL); baseURL != nil {
|
||||||
base = baseURL.(string)
|
base = baseURL.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,16 +50,16 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
|
||||||
var scheme = "https"
|
var scheme = "https"
|
||||||
|
|
||||||
if !https {
|
if !https {
|
||||||
proto := string(ctx.Request.Header.Peek(fasthttp.HeaderXForwardedProto))
|
proto := string(ctx.XForwardedProto())
|
||||||
switch proto {
|
switch proto {
|
||||||
case "":
|
case "":
|
||||||
scheme = "http"
|
break
|
||||||
default:
|
case "http", "https":
|
||||||
scheme = proto
|
scheme = proto
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := scheme + "://" + string(ctx.Request.Host()) + base + "/"
|
baseURL := scheme + "://" + string(ctx.XForwardedHost()) + base + "/"
|
||||||
nonce := utils.RandomString(32, utils.AlphaNumericCharacters, true)
|
nonce := utils.RandomString(32, utils.AlphaNumericCharacters, true)
|
||||||
|
|
||||||
switch extension := filepath.Ext(file); extension {
|
switch extension := filepath.Ext(file); extension {
|
||||||
|
@ -81,7 +80,7 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
|
||||||
|
|
||||||
err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, BaseURL, CSPNonce, DuoSelfEnrollment, LogoOverride, RememberMe, ResetPassword, Session, Theme string }{Base: base, BaseURL: baseURL, CSPNonce: nonce, DuoSelfEnrollment: duoSelfEnrollment, LogoOverride: logoOverride, RememberMe: rememberMe, ResetPassword: resetPassword, Session: session, Theme: theme})
|
err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, BaseURL, CSPNonce, DuoSelfEnrollment, LogoOverride, RememberMe, ResetPassword, Session, Theme string }{Base: base, BaseURL: baseURL, CSPNonce: nonce, DuoSelfEnrollment: duoSelfEnrollment, LogoOverride: logoOverride, RememberMe: rememberMe, ResetPassword: resetPassword, Session: session, Theme: theme})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Error("an error occurred", 503)
|
ctx.RequestCtx.Error("an error occurred", 503)
|
||||||
logger.Errorf("Unable to execute template: %v", err)
|
logger.Errorf("Unable to execute template: %v", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue