refactor(handlers): utilize referer for auth logging rm/rd (#2655)

This utilizes the referrer query parameters instead of current request query parameters for logging the requested URI and method. Minor performance improvements to header peek/sets.
pull/2658/head
James Elliott 2021-12-02 13:21:46 +11:00 committed by GitHub
parent f3f3b31b12
commit bf9ab360bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 56 deletions

View File

@ -1,5 +1,9 @@
package handlers package handlers
import (
"github.com/valyala/fasthttp"
)
const ( const (
// ActionTOTPRegistration is the string representation of the action for which the token has been produced. // ActionTOTPRegistration is the string representation of the action for which the token has been produced.
ActionTOTPRegistration = "RegisterTOTPDevice" ActionTOTPRegistration = "RegisterTOTPDevice"
@ -11,20 +15,15 @@ const (
ActionResetPassword = "ResetPassword" ActionResetPassword = "ResetPassword"
) )
const ( var (
// HeaderProxyAuthorization is the basic-auth HTTP header Authelia utilises. headerAuthorization = []byte(fasthttp.HeaderAuthorization)
HeaderProxyAuthorization = "Proxy-Authorization" headerProxyAuthorization = []byte(fasthttp.HeaderProxyAuthorization)
// HeaderAuthorization is the basic-auth HTTP header Authelia utilises with "auth=basic" query param. headerSessionUsername = []byte("Session-Username")
HeaderAuthorization = "Authorization" headerRemoteUser = []byte("Remote-User")
headerRemoteGroups = []byte("Remote-Groups")
// HeaderSessionUsername is used as additional protection to validate a user for things like pam_exec. headerRemoteName = []byte("Remote-Name")
HeaderSessionUsername = "Session-Username" headerRemoteEmail = []byte("Remote-Email")
headerRemoteUser = "Remote-User"
headerRemoteName = "Remote-Name"
headerRemoteEmail = "Remote-Email"
headerRemoteGroups = "Remote-Groups"
) )
const ( const (

View File

@ -33,7 +33,7 @@ func isSchemeWSS(url *url.URL) bool {
// parseBasicAuth parses an HTTP Basic Authentication string. // parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). // "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
func parseBasicAuth(header, auth string) (username, password string, err error) { func parseBasicAuth(header []byte, auth string) (username, password string, err error) {
if !strings.HasPrefix(auth, authPrefix) { if !strings.HasPrefix(auth, authPrefix) {
return "", "", fmt.Errorf("%s prefix not found in %s header", strings.Trim(authPrefix, " "), header) return "", "", fmt.Errorf("%s prefix not found in %s header", strings.Trim(authPrefix, " "), header)
} }
@ -85,7 +85,7 @@ func isTargetURLAuthorized(authorizer *authorization.Authorizer, targetURL url.U
// verifyBasicAuth verify that the provided username and password are correct and // verifyBasicAuth verify that the provided username and password are correct and
// that the user is authorized to target the resource. // that the user is authorized to target the resource.
func verifyBasicAuth(header string, auth []byte, ctx *middlewares.AutheliaCtx) (username, name string, groups, emails []string, authLevel authentication.Level, err error) { func verifyBasicAuth(ctx *middlewares.AutheliaCtx, header, auth []byte) (username, name string, groups, emails []string, authLevel authentication.Level, err error) {
username, password, err := parseBasicAuth(header, string(auth)) username, password, err := parseBasicAuth(header, string(auth))
if err != nil { if err != nil {
@ -116,14 +116,14 @@ func verifyBasicAuth(header string, auth []byte, ctx *middlewares.AutheliaCtx) (
// setForwardedHeaders set the forwarded User, Groups, Name and Email headers. // setForwardedHeaders set the forwarded User, Groups, Name and Email headers.
func setForwardedHeaders(headers *fasthttp.ResponseHeader, username, name string, groups, emails []string) { func setForwardedHeaders(headers *fasthttp.ResponseHeader, username, name string, groups, emails []string) {
if username != "" { if username != "" {
headers.Set(headerRemoteUser, username) headers.SetBytesK(headerRemoteUser, username)
headers.Set(headerRemoteGroups, strings.Join(groups, ",")) headers.SetBytesK(headerRemoteGroups, strings.Join(groups, ","))
headers.Set(headerRemoteName, name) headers.SetBytesK(headerRemoteName, name)
if emails != nil { if emails != nil {
headers.Set(headerRemoteEmail, emails[0]) headers.SetBytesK(headerRemoteEmail, emails[0])
} else { } else {
headers.Set(headerRemoteEmail, "") headers.SetBytesK(headerRemoteEmail, "")
} }
} }
} }
@ -403,13 +403,13 @@ func getProfileRefreshSettings(cfg schema.AuthenticationBackendConfiguration) (r
} }
func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile bool, refreshProfileInterval time.Duration) (isBasicAuth bool, username, name string, groups, emails []string, authLevel authentication.Level, err error) { func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile bool, refreshProfileInterval time.Duration) (isBasicAuth bool, username, name string, groups, emails []string, authLevel authentication.Level, err error) {
authHeader := HeaderProxyAuthorization authHeader := headerProxyAuthorization
if bytes.Equal(ctx.QueryArgs().Peek("auth"), []byte("basic")) { if bytes.Equal(ctx.QueryArgs().Peek("auth"), []byte("basic")) {
authHeader = HeaderAuthorization authHeader = headerAuthorization
isBasicAuth = true isBasicAuth = true
} }
authValue := ctx.Request.Header.Peek(authHeader) authValue := ctx.Request.Header.PeekBytes(authHeader)
if authValue != nil { if authValue != nil {
isBasicAuth = true isBasicAuth = true
} else if isBasicAuth { } else if isBasicAuth {
@ -418,23 +418,23 @@ func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile
} }
if isBasicAuth { if isBasicAuth {
username, name, groups, emails, authLevel, err = verifyBasicAuth(authHeader, authValue, ctx) username, name, groups, emails, authLevel, err = verifyBasicAuth(ctx, authHeader, authValue)
return return
} }
userSession := ctx.GetSession() userSession := ctx.GetSession()
username, name, groups, emails, authLevel, err = verifySessionCookie(ctx, targetURL, &userSession, refreshProfile, refreshProfileInterval) username, name, groups, emails, authLevel, err = verifySessionCookie(ctx, targetURL, &userSession, refreshProfile, refreshProfileInterval)
sessionUsername := ctx.Request.Header.Peek(HeaderSessionUsername) sessionUsername := ctx.Request.Header.PeekBytes(headerSessionUsername)
if sessionUsername != nil && !strings.EqualFold(string(sessionUsername), username) { if sessionUsername != nil && !strings.EqualFold(string(sessionUsername), username) {
ctx.Logger.Warnf("Possible cookie hijack or attempt to bypass security detected destroying the session and sending 401 response") ctx.Logger.Warnf("Possible cookie hijack or attempt to bypass security detected destroying the session and sending 401 response")
err = ctx.Providers.SessionProvider.DestroySession(ctx.RequestCtx) err = ctx.Providers.SessionProvider.DestroySession(ctx.RequestCtx)
if err != nil { if err != nil {
ctx.Logger.Errorf("Unable to destroy user session after handler could not match them to their %s header: %s", HeaderSessionUsername, err) ctx.Logger.Errorf("Unable to destroy user session after handler could not match them to their %s header: %s", headerSessionUsername, err)
} }
err = fmt.Errorf("could not match user %s to their %s header with a value of %s when visiting %s", username, HeaderSessionUsername, sessionUsername, targetURL.String()) err = fmt.Errorf("could not match user %s to their %s header with a value of %s when visiting %s", username, headerSessionUsername, sessionUsername, targetURL.String())
} }
return return

View File

@ -85,34 +85,34 @@ func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) {
// Test parseBasicAuth. // Test parseBasicAuth.
func TestShouldRaiseWhenHeaderDoesNotContainBasicPrefix(t *testing.T) { func TestShouldRaiseWhenHeaderDoesNotContainBasicPrefix(t *testing.T) {
_, _, err := parseBasicAuth(HeaderProxyAuthorization, "alzefzlfzemjfej==") _, _, err := parseBasicAuth(headerProxyAuthorization, "alzefzlfzemjfej==")
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "Basic prefix not found in Proxy-Authorization header", err.Error()) assert.Equal(t, "Basic prefix not found in Proxy-Authorization header", err.Error())
} }
func TestShouldRaiseWhenCredentialsAreNotInBase64(t *testing.T) { func TestShouldRaiseWhenCredentialsAreNotInBase64(t *testing.T) {
_, _, err := parseBasicAuth(HeaderProxyAuthorization, "Basic alzefzlfzemjfej==") _, _, err := parseBasicAuth(headerProxyAuthorization, "Basic alzefzlfzemjfej==")
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "illegal base64 data at input byte 16", err.Error()) assert.Equal(t, "illegal base64 data at input byte 16", err.Error())
} }
func TestShouldRaiseWhenCredentialsAreNotInCorrectForm(t *testing.T) { func TestShouldRaiseWhenCredentialsAreNotInCorrectForm(t *testing.T) {
// The decoded format should be user:password. // The decoded format should be user:password.
_, _, err := parseBasicAuth(HeaderProxyAuthorization, "Basic am9obiBwYXNzd29yZA==") _, _, err := parseBasicAuth(headerProxyAuthorization, "Basic am9obiBwYXNzd29yZA==")
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "format of Proxy-Authorization header must be user:password", err.Error()) assert.Equal(t, "format of Proxy-Authorization header must be user:password", err.Error())
} }
func TestShouldUseProvidedHeaderName(t *testing.T) { func TestShouldUseProvidedHeaderName(t *testing.T) {
// The decoded format should be user:password. // The decoded format should be user:password.
_, _, err := parseBasicAuth("HeaderName", "") _, _, err := parseBasicAuth([]byte("HeaderName"), "")
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "Basic prefix not found in HeaderName header", err.Error()) assert.Equal(t, "Basic prefix not found in HeaderName header", err.Error())
} }
func TestShouldReturnUsernameAndPassword(t *testing.T) { func TestShouldReturnUsernameAndPassword(t *testing.T) {
// the decoded format should be user:password. // the decoded format should be user:password.
user, password, err := parseBasicAuth(HeaderProxyAuthorization, "Basic am9objpwYXNzd29yZA==") user, password, err := parseBasicAuth(headerProxyAuthorization, "Basic am9objpwYXNzd29yZA==")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "john", user) assert.Equal(t, "john", user)
assert.Equal(t, "password", password) assert.Equal(t, "password", password)
@ -176,7 +176,7 @@ func TestShouldVerifyWrongCredentials(t *testing.T) {
CheckUserPassword(gomock.Eq("john"), gomock.Eq("password")). CheckUserPassword(gomock.Eq("john"), gomock.Eq("password")).
Return(false, nil) Return(false, nil)
_, _, _, _, _, err := verifyBasicAuth(HeaderProxyAuthorization, []byte("Basic am9objpwYXNzd29yZA=="), mock.Ctx) _, _, _, _, _, err := verifyBasicAuth(mock.Ctx, headerProxyAuthorization, []byte("Basic am9objpwYXNzd29yZA=="))
assert.Error(t, err) assert.Error(t, err)
} }
@ -1211,7 +1211,7 @@ func TestShouldCheckValidSessionUsernameHeaderAndReturn200(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com") mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com")
mock.Ctx.Request.Header.Set(HeaderSessionUsername, testUsername) mock.Ctx.Request.Header.SetBytesK(headerSessionUsername, testUsername)
VerifyGet(verifyGetCfg)(mock.Ctx) VerifyGet(verifyGetCfg)(mock.Ctx)
assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode()) assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode())
@ -1235,7 +1235,7 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com") mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com")
mock.Ctx.Request.Header.Set(HeaderSessionUsername, "root") mock.Ctx.Request.Header.SetBytesK(headerSessionUsername, "root")
VerifyGet(verifyGetCfg)(mock.Ctx) VerifyGet(verifyGetCfg)(mock.Ctx)
assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode()) assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode())

View File

@ -150,7 +150,20 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba
// We only Mark if there was no underlying error. // We only Mark if there was no underlying error.
ctx.Logger.Debugf("Mark %s authentication attempt made by user '%s'", authType, username) ctx.Logger.Debugf("Mark %s authentication attempt made by user '%s'", authType, username)
if err = ctx.Providers.Regulator.Mark(ctx, successful, bannedUntil != nil, username, string(ctx.RequestCtx.QueryArgs().Peek("rd")), string(ctx.RequestCtx.QueryArgs().Peek("rm")), authType, ctx.RemoteIP()); err != nil { var (
requestURI, requestMethod string
)
referer := ctx.Request.Header.Referer()
if referer != nil {
refererURL, err := url.Parse(string(referer))
if err == nil {
requestURI = refererURL.Query().Get("rd")
requestMethod = refererURL.Query().Get("rm")
}
}
if err = ctx.Providers.Regulator.Mark(ctx, successful, bannedUntil != nil, username, requestURI, requestMethod, authType, ctx.RemoteIP()); err != nil {
ctx.Logger.Errorf("Unable to mark %s authentication attempt by user '%s': %+v", authType, username, err) ctx.Logger.Errorf("Unable to mark %s authentication attempt by user '%s': %+v", authType, username, err)
return err return err

View File

@ -102,22 +102,22 @@ func (c *AutheliaCtx) ReplyBadRequest() {
// XForwardedProto return the content of the X-Forwarded-Proto header. // XForwardedProto return the content of the X-Forwarded-Proto header.
func (c *AutheliaCtx) XForwardedProto() []byte { func (c *AutheliaCtx) XForwardedProto() []byte {
return c.RequestCtx.Request.Header.Peek(headerXForwardedProto) return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto)
} }
// XForwardedMethod return the content of the X-Forwarded-Method header. // XForwardedMethod return the content of the X-Forwarded-Method header.
func (c *AutheliaCtx) XForwardedMethod() []byte { func (c *AutheliaCtx) XForwardedMethod() []byte {
return c.RequestCtx.Request.Header.Peek(headerXForwardedMethod) return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
} }
// XForwardedHost return the content of the X-Forwarded-Host header. // XForwardedHost return the content of the X-Forwarded-Host header.
func (c *AutheliaCtx) XForwardedHost() []byte { func (c *AutheliaCtx) XForwardedHost() []byte {
return c.RequestCtx.Request.Header.Peek(headerXForwardedHost) return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost)
} }
// XForwardedURI return the content of the X-Forwarded-URI header. // XForwardedURI return the content of the X-Forwarded-URI header.
func (c *AutheliaCtx) XForwardedURI() []byte { func (c *AutheliaCtx) XForwardedURI() []byte {
return c.RequestCtx.Request.Header.Peek(headerXForwardedURI) return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI)
} }
// BasePath returns the base_url as per the path visited by the client. // BasePath returns the base_url as per the path visited by the client.
@ -159,7 +159,7 @@ func (c *AutheliaCtx) ExternalRootURL() (string, error) {
// XOriginalURL return the content of the X-Original-URL header. // XOriginalURL return the content of the X-Original-URL header.
func (c *AutheliaCtx) XOriginalURL() []byte { func (c *AutheliaCtx) XOriginalURL() []byte {
return c.RequestCtx.Request.Header.Peek(headerXOriginalURL) return c.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
} }
// GetSession return the user session. Any update will be saved in cache. // GetSession return the user session. Any update will be saved in cache.
@ -220,7 +220,7 @@ func (c *AutheliaCtx) SetJSONBody(value interface{}) error {
// RemoteIP return the remote IP taking X-Forwarded-For header into account if provided. // RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
func (c *AutheliaCtx) RemoteIP() net.IP { func (c *AutheliaCtx) RemoteIP() net.IP {
XForwardedFor := c.Request.Header.Peek("X-Forwarded-For") XForwardedFor := c.Request.Header.PeekBytes(headerXForwardedFor)
if XForwardedFor != nil { if XForwardedFor != nil {
ips := strings.Split(string(XForwardedFor), ",") ips := strings.Split(string(XForwardedFor), ",")
@ -278,14 +278,14 @@ func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
// IsXHR returns true if the request is a XMLHttpRequest. // IsXHR returns true if the request is a XMLHttpRequest.
func (c AutheliaCtx) IsXHR() (xhr bool) { func (c AutheliaCtx) IsXHR() (xhr bool) {
requestedWith := c.Request.Header.Peek(headerXRequestedWith) requestedWith := c.Request.Header.PeekBytes(headerXRequestedWith)
return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR
} }
// AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type. // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) { func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
accepts := strings.Split(string(c.Request.Header.Peek("Accept")), ",") accepts := strings.Split(string(c.Request.Header.PeekBytes(headerAccept)), ",")
for i, accept := range accepts { for i, accept := range accepts {
mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ") mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ")

View File

@ -1,19 +1,23 @@
package middlewares package middlewares
const ( import (
headerXForwardedProto = "X-Forwarded-Proto" "github.com/valyala/fasthttp"
headerXForwardedMethod = "X-Forwarded-Method" )
headerXForwardedHost = "X-Forwarded-Host"
headerXForwardedURI = "X-Forwarded-URI" var (
headerXOriginalURL = "X-Original-URL" headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
headerXRequestedWith = "X-Requested-With" headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor)
headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith)
headerAccept = []byte(fasthttp.HeaderAccept)
headerXForwardedURI = []byte("X-Forwarded-URI")
headerXOriginalURL = []byte("X-Original-URL")
headerXForwardedMethod = []byte("X-Forwarded-Method")
) )
const ( const (
headerValueXRequestedWithXHR = "XMLHttpRequest" headerValueXRequestedWithXHR = "XMLHttpRequest"
)
const (
contentTypeApplicationJSON = "application/json" contentTypeApplicationJSON = "application/json"
contentTypeTextHTML = "text/html" contentTypeTextHTML = "text/html"
) )