diff --git a/internal/handlers/const.go b/internal/handlers/const.go index fe292e72c..d3bc4e986 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -1,5 +1,9 @@ package handlers +import ( + "github.com/valyala/fasthttp" +) + const ( // ActionTOTPRegistration is the string representation of the action for which the token has been produced. ActionTOTPRegistration = "RegisterTOTPDevice" @@ -11,20 +15,15 @@ const ( ActionResetPassword = "ResetPassword" ) -const ( - // HeaderProxyAuthorization is the basic-auth HTTP header Authelia utilises. - HeaderProxyAuthorization = "Proxy-Authorization" +var ( + headerAuthorization = []byte(fasthttp.HeaderAuthorization) + headerProxyAuthorization = []byte(fasthttp.HeaderProxyAuthorization) - // HeaderAuthorization is the basic-auth HTTP header Authelia utilises with "auth=basic" query param. - HeaderAuthorization = "Authorization" - - // HeaderSessionUsername is used as additional protection to validate a user for things like pam_exec. - HeaderSessionUsername = "Session-Username" - - headerRemoteUser = "Remote-User" - headerRemoteName = "Remote-Name" - headerRemoteEmail = "Remote-Email" - headerRemoteGroups = "Remote-Groups" + headerSessionUsername = []byte("Session-Username") + headerRemoteUser = []byte("Remote-User") + headerRemoteGroups = []byte("Remote-Groups") + headerRemoteName = []byte("Remote-Name") + headerRemoteEmail = []byte("Remote-Email") ) const ( diff --git a/internal/handlers/handler_verify.go b/internal/handlers/handler_verify.go index 6fce44741..ba3f71f6b 100644 --- a/internal/handlers/handler_verify.go +++ b/internal/handlers/handler_verify.go @@ -33,7 +33,7 @@ func isSchemeWSS(url *url.URL) bool { // parseBasicAuth parses an HTTP Basic Authentication string. // "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). -func parseBasicAuth(header, auth string) (username, password string, err error) { +func parseBasicAuth(header []byte, auth string) (username, password string, err error) { if !strings.HasPrefix(auth, authPrefix) { return "", "", fmt.Errorf("%s prefix not found in %s header", strings.Trim(authPrefix, " "), header) } @@ -85,7 +85,7 @@ func isTargetURLAuthorized(authorizer *authorization.Authorizer, targetURL url.U // verifyBasicAuth verify that the provided username and password are correct and // that the user is authorized to target the resource. -func verifyBasicAuth(header string, auth []byte, ctx *middlewares.AutheliaCtx) (username, name string, groups, emails []string, authLevel authentication.Level, err error) { +func verifyBasicAuth(ctx *middlewares.AutheliaCtx, header, auth []byte) (username, name string, groups, emails []string, authLevel authentication.Level, err error) { username, password, err := parseBasicAuth(header, string(auth)) if err != nil { @@ -116,14 +116,14 @@ func verifyBasicAuth(header string, auth []byte, ctx *middlewares.AutheliaCtx) ( // setForwardedHeaders set the forwarded User, Groups, Name and Email headers. func setForwardedHeaders(headers *fasthttp.ResponseHeader, username, name string, groups, emails []string) { if username != "" { - headers.Set(headerRemoteUser, username) - headers.Set(headerRemoteGroups, strings.Join(groups, ",")) - headers.Set(headerRemoteName, name) + headers.SetBytesK(headerRemoteUser, username) + headers.SetBytesK(headerRemoteGroups, strings.Join(groups, ",")) + headers.SetBytesK(headerRemoteName, name) if emails != nil { - headers.Set(headerRemoteEmail, emails[0]) + headers.SetBytesK(headerRemoteEmail, emails[0]) } else { - headers.Set(headerRemoteEmail, "") + headers.SetBytesK(headerRemoteEmail, "") } } } @@ -403,13 +403,13 @@ func getProfileRefreshSettings(cfg schema.AuthenticationBackendConfiguration) (r } func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile bool, refreshProfileInterval time.Duration) (isBasicAuth bool, username, name string, groups, emails []string, authLevel authentication.Level, err error) { - authHeader := HeaderProxyAuthorization + authHeader := headerProxyAuthorization if bytes.Equal(ctx.QueryArgs().Peek("auth"), []byte("basic")) { - authHeader = HeaderAuthorization + authHeader = headerAuthorization isBasicAuth = true } - authValue := ctx.Request.Header.Peek(authHeader) + authValue := ctx.Request.Header.PeekBytes(authHeader) if authValue != nil { isBasicAuth = true } else if isBasicAuth { @@ -418,23 +418,23 @@ func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile } if isBasicAuth { - username, name, groups, emails, authLevel, err = verifyBasicAuth(authHeader, authValue, ctx) + username, name, groups, emails, authLevel, err = verifyBasicAuth(ctx, authHeader, authValue) return } userSession := ctx.GetSession() username, name, groups, emails, authLevel, err = verifySessionCookie(ctx, targetURL, &userSession, refreshProfile, refreshProfileInterval) - sessionUsername := ctx.Request.Header.Peek(HeaderSessionUsername) + sessionUsername := ctx.Request.Header.PeekBytes(headerSessionUsername) if sessionUsername != nil && !strings.EqualFold(string(sessionUsername), username) { ctx.Logger.Warnf("Possible cookie hijack or attempt to bypass security detected destroying the session and sending 401 response") err = ctx.Providers.SessionProvider.DestroySession(ctx.RequestCtx) if err != nil { - ctx.Logger.Errorf("Unable to destroy user session after handler could not match them to their %s header: %s", HeaderSessionUsername, err) + ctx.Logger.Errorf("Unable to destroy user session after handler could not match them to their %s header: %s", headerSessionUsername, err) } - err = fmt.Errorf("could not match user %s to their %s header with a value of %s when visiting %s", username, HeaderSessionUsername, sessionUsername, targetURL.String()) + err = fmt.Errorf("could not match user %s to their %s header with a value of %s when visiting %s", username, headerSessionUsername, sessionUsername, targetURL.String()) } return diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go index cf1faa17f..bb78ff022 100644 --- a/internal/handlers/handler_verify_test.go +++ b/internal/handlers/handler_verify_test.go @@ -85,34 +85,34 @@ func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) { // Test parseBasicAuth. func TestShouldRaiseWhenHeaderDoesNotContainBasicPrefix(t *testing.T) { - _, _, err := parseBasicAuth(HeaderProxyAuthorization, "alzefzlfzemjfej==") + _, _, err := parseBasicAuth(headerProxyAuthorization, "alzefzlfzemjfej==") assert.Error(t, err) assert.Equal(t, "Basic prefix not found in Proxy-Authorization header", err.Error()) } func TestShouldRaiseWhenCredentialsAreNotInBase64(t *testing.T) { - _, _, err := parseBasicAuth(HeaderProxyAuthorization, "Basic alzefzlfzemjfej==") + _, _, err := parseBasicAuth(headerProxyAuthorization, "Basic alzefzlfzemjfej==") assert.Error(t, err) assert.Equal(t, "illegal base64 data at input byte 16", err.Error()) } func TestShouldRaiseWhenCredentialsAreNotInCorrectForm(t *testing.T) { // The decoded format should be user:password. - _, _, err := parseBasicAuth(HeaderProxyAuthorization, "Basic am9obiBwYXNzd29yZA==") + _, _, err := parseBasicAuth(headerProxyAuthorization, "Basic am9obiBwYXNzd29yZA==") assert.Error(t, err) assert.Equal(t, "format of Proxy-Authorization header must be user:password", err.Error()) } func TestShouldUseProvidedHeaderName(t *testing.T) { // The decoded format should be user:password. - _, _, err := parseBasicAuth("HeaderName", "") + _, _, err := parseBasicAuth([]byte("HeaderName"), "") assert.Error(t, err) assert.Equal(t, "Basic prefix not found in HeaderName header", err.Error()) } func TestShouldReturnUsernameAndPassword(t *testing.T) { // the decoded format should be user:password. - user, password, err := parseBasicAuth(HeaderProxyAuthorization, "Basic am9objpwYXNzd29yZA==") + user, password, err := parseBasicAuth(headerProxyAuthorization, "Basic am9objpwYXNzd29yZA==") assert.NoError(t, err) assert.Equal(t, "john", user) assert.Equal(t, "password", password) @@ -176,7 +176,7 @@ func TestShouldVerifyWrongCredentials(t *testing.T) { CheckUserPassword(gomock.Eq("john"), gomock.Eq("password")). Return(false, nil) - _, _, _, _, _, err := verifyBasicAuth(HeaderProxyAuthorization, []byte("Basic am9objpwYXNzd29yZA=="), mock.Ctx) + _, _, _, _, _, err := verifyBasicAuth(mock.Ctx, headerProxyAuthorization, []byte("Basic am9objpwYXNzd29yZA==")) assert.Error(t, err) } @@ -1211,7 +1211,7 @@ func TestShouldCheckValidSessionUsernameHeaderAndReturn200(t *testing.T) { require.NoError(t, err) mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com") - mock.Ctx.Request.Header.Set(HeaderSessionUsername, testUsername) + mock.Ctx.Request.Header.SetBytesK(headerSessionUsername, testUsername) VerifyGet(verifyGetCfg)(mock.Ctx) assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode()) @@ -1235,7 +1235,7 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) { require.NoError(t, err) mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com") - mock.Ctx.Request.Header.Set(HeaderSessionUsername, "root") + mock.Ctx.Request.Header.SetBytesK(headerSessionUsername, "root") VerifyGet(verifyGetCfg)(mock.Ctx) assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode()) diff --git a/internal/handlers/response.go b/internal/handlers/response.go index 34a26860f..1f3f5220e 100644 --- a/internal/handlers/response.go +++ b/internal/handlers/response.go @@ -150,7 +150,20 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba // We only Mark if there was no underlying error. ctx.Logger.Debugf("Mark %s authentication attempt made by user '%s'", authType, username) - if err = ctx.Providers.Regulator.Mark(ctx, successful, bannedUntil != nil, username, string(ctx.RequestCtx.QueryArgs().Peek("rd")), string(ctx.RequestCtx.QueryArgs().Peek("rm")), authType, ctx.RemoteIP()); err != nil { + var ( + requestURI, requestMethod string + ) + + referer := ctx.Request.Header.Referer() + if referer != nil { + refererURL, err := url.Parse(string(referer)) + if err == nil { + requestURI = refererURL.Query().Get("rd") + requestMethod = refererURL.Query().Get("rm") + } + } + + if err = ctx.Providers.Regulator.Mark(ctx, successful, bannedUntil != nil, username, requestURI, requestMethod, authType, ctx.RemoteIP()); err != nil { ctx.Logger.Errorf("Unable to mark %s authentication attempt by user '%s': %+v", authType, username, err) return err diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 4a1242cab..1552cd3d7 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -102,22 +102,22 @@ func (c *AutheliaCtx) ReplyBadRequest() { // XForwardedProto return the content of the X-Forwarded-Proto header. func (c *AutheliaCtx) XForwardedProto() []byte { - return c.RequestCtx.Request.Header.Peek(headerXForwardedProto) + return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto) } // XForwardedMethod return the content of the X-Forwarded-Method header. func (c *AutheliaCtx) XForwardedMethod() []byte { - return c.RequestCtx.Request.Header.Peek(headerXForwardedMethod) + return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod) } // XForwardedHost return the content of the X-Forwarded-Host header. func (c *AutheliaCtx) XForwardedHost() []byte { - return c.RequestCtx.Request.Header.Peek(headerXForwardedHost) + return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost) } // XForwardedURI return the content of the X-Forwarded-URI header. func (c *AutheliaCtx) XForwardedURI() []byte { - return c.RequestCtx.Request.Header.Peek(headerXForwardedURI) + return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI) } // BasePath returns the base_url as per the path visited by the client. @@ -159,7 +159,7 @@ func (c *AutheliaCtx) ExternalRootURL() (string, error) { // XOriginalURL return the content of the X-Original-URL header. func (c *AutheliaCtx) XOriginalURL() []byte { - return c.RequestCtx.Request.Header.Peek(headerXOriginalURL) + return c.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL) } // GetSession return the user session. Any update will be saved in cache. @@ -220,7 +220,7 @@ func (c *AutheliaCtx) SetJSONBody(value interface{}) error { // RemoteIP return the remote IP taking X-Forwarded-For header into account if provided. func (c *AutheliaCtx) RemoteIP() net.IP { - XForwardedFor := c.Request.Header.Peek("X-Forwarded-For") + XForwardedFor := c.Request.Header.PeekBytes(headerXForwardedFor) if XForwardedFor != nil { ips := strings.Split(string(XForwardedFor), ",") @@ -278,14 +278,14 @@ func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) { // IsXHR returns true if the request is a XMLHttpRequest. func (c AutheliaCtx) IsXHR() (xhr bool) { - requestedWith := c.Request.Header.Peek(headerXRequestedWith) + requestedWith := c.Request.Header.PeekBytes(headerXRequestedWith) return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR } // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type. func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { - accepts := strings.Split(string(c.Request.Header.Peek("Accept")), ",") + accepts := strings.Split(string(c.Request.Header.PeekBytes(headerAccept)), ",") for i, accept := range accepts { mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ") diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index 23de4427b..01fc3f1e1 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -1,21 +1,25 @@ package middlewares -const ( - headerXForwardedProto = "X-Forwarded-Proto" - headerXForwardedMethod = "X-Forwarded-Method" - headerXForwardedHost = "X-Forwarded-Host" - headerXForwardedURI = "X-Forwarded-URI" - headerXOriginalURL = "X-Original-URL" - headerXRequestedWith = "X-Requested-With" +import ( + "github.com/valyala/fasthttp" +) + +var ( + headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto) + headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost) + headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor) + headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith) + headerAccept = []byte(fasthttp.HeaderAccept) + + headerXForwardedURI = []byte("X-Forwarded-URI") + headerXOriginalURL = []byte("X-Original-URL") + headerXForwardedMethod = []byte("X-Forwarded-Method") ) const ( headerValueXRequestedWithXHR = "XMLHttpRequest" -) - -const ( - contentTypeApplicationJSON = "application/json" - contentTypeTextHTML = "text/html" + contentTypeApplicationJSON = "application/json" + contentTypeTextHTML = "text/html" ) var okMessageBytes = []byte("{\"status\":\"OK\"}")