refactor(middlewares): factorize responses (#3628)

pull/3544/head^2
James Elliott 2022-07-08 22:18:52 +10:00 committed by GitHub
parent f0084cb711
commit ce779b2533
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 73 additions and 57 deletions

View File

@ -29,7 +29,7 @@ var (
)
var (
headerContentTypeValueDefault = []byte("text/plain; charset=utf-8")
headerContentTypeValueTextPlain = []byte("text/plain; charset=utf-8")
)
const (

View File

@ -1,8 +1,6 @@
package handlers
import (
"encoding/json"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/middlewares"
@ -18,19 +16,19 @@ func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
ctx.Response.SetStatusCode(fasthttp.StatusBadRequest)
ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
return
}
wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer)
ctx.SetContentType("application/json")
if err = json.NewEncoder(ctx).Encode(wellKnown); err != nil {
if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
// TODO: Determine if this is the appropriate error code here.
ctx.Response.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.ReplyStatusCode(fasthttp.StatusInternalServerError)
return
}
@ -46,19 +44,19 @@ func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
ctx.Response.SetStatusCode(fasthttp.StatusBadRequest)
ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
return
}
wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer)
ctx.SetContentType("application/json")
if err = json.NewEncoder(ctx).Encode(wellKnown); err != nil {
if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
// TODO: Determine if this is the appropriate error code here.
ctx.Response.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.ReplyStatusCode(fasthttp.StatusInternalServerError)
return
}

View File

@ -52,7 +52,7 @@ func (s *StateGetSuite) TestShouldReturnUsernameFromSession() {
err = json.Unmarshal(s.mock.Ctx.Response.Body(), &actualBody)
require.NoError(s.T(), err)
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
assert.Equal(s.T(), []byte("application/json"), s.mock.Ctx.Response.Header.ContentType())
assert.Equal(s.T(), []byte("application/json; charset=utf-8"), s.mock.Ctx.Response.Header.ContentType())
assert.Equal(s.T(), expectedBody, actualBody)
}
@ -82,7 +82,7 @@ func (s *StateGetSuite) TestShouldReturnAuthenticationLevelFromSession() {
err = json.Unmarshal(s.mock.Ctx.Response.Body(), &actualBody)
require.NoError(s.T(), err)
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
assert.Equal(s.T(), []byte("application/json"), s.mock.Ctx.Response.Header.ContentType())
assert.Equal(s.T(), []byte("application/json; charset=utf-8"), s.mock.Ctx.Response.Header.ContentType())
assert.Equal(s.T(), expectedBody, actualBody)
}

View File

@ -351,7 +351,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingNoHeader()
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@ -367,7 +367,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingEmptyHeader
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@ -387,7 +387,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingWrongPasswo
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@ -403,7 +403,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingWrongHeader
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@ -721,7 +721,7 @@ func TestShouldRedirectWhenSessionInactiveForTooLongAndRDParamProvided(t *testin
mock.Ctx.Request.Header.Set("Accept", "text/html; charset=utf-8")
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(t, "<a href=\"https://login.example.com/?rd=https%3A%2F%2Ftwo-factor.example.com&amp;rm=GET\">Found</a>",
assert.Equal(t, "<a href=\"https://login.example.com/?rd=https%3A%2F%2Ftwo-factor.example.com&amp;rm=GET\">302 Found</a>",
string(mock.Ctx.Response.Body()))
assert.Equal(t, 302, mock.Ctx.Response.StatusCode())
@ -741,7 +741,7 @@ func TestShouldRedirectWithCorrectStatusCodeBasedOnRequestMethod(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(t, "<a href=\"https://login.example.com/?rd=https%3A%2F%2Ftwo-factor.example.com&amp;rm=GET\">Found</a>",
assert.Equal(t, "<a href=\"https://login.example.com/?rd=https%3A%2F%2Ftwo-factor.example.com&amp;rm=GET\">302 Found</a>",
string(mock.Ctx.Response.Body()))
assert.Equal(t, 302, mock.Ctx.Response.StatusCode())
@ -752,7 +752,7 @@ func TestShouldRedirectWithCorrectStatusCodeBasedOnRequestMethod(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(t, "<a href=\"https://login.example.com/?rd=https%3A%2F%2Ftwo-factor.example.com&amp;rm=POST\">See Other</a>",
assert.Equal(t, "<a href=\"https://login.example.com/?rd=https%3A%2F%2Ftwo-factor.example.com&amp;rm=POST\">303 See Other</a>",
string(mock.Ctx.Response.Body()))
assert.Equal(t, 303, mock.Ctx.Response.StatusCode())
}
@ -809,7 +809,7 @@ func TestShouldURLEncodeRedirectionURLParameter(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(t, "<a href=\"https://auth.mydomain.com/?rd=https%3A%2F%2Ftwo-factor.example.com\">Found</a>",
assert.Equal(t, "<a href=\"https://auth.mydomain.com/?rd=https%3A%2F%2Ftwo-factor.example.com\">302 Found</a>",
string(mock.Ctx.Response.Body()))
}
@ -1240,7 +1240,7 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode())
assert.Equal(t, "Unauthorized", string(mock.Ctx.Response.Body()))
assert.Equal(t, "401 Unauthorized", string(mock.Ctx.Response.Body()))
}
func TestGetProfileRefreshSettings(t *testing.T) {

View File

@ -250,7 +250,7 @@ func respondUnauthorized(ctx *middlewares.AutheliaCtx, message string) {
// *fasthttp.RequestCtx or *middlewares.AutheliaCtx.
func SetStatusCodeResponse(ctx *fasthttp.RequestCtx, statusCode int) {
ctx.Response.Reset()
ctx.SetContentTypeBytes(headerContentTypeValueDefault)
ctx.SetContentTypeBytes(headerContentTypeValueTextPlain)
ctx.SetStatusCode(statusCode)
ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode)))
}

View File

@ -68,14 +68,9 @@ func (ctx *AutheliaCtx) Error(err error, message string) {
// SetJSONError sets the body of the response to an JSON error KO message.
func (ctx *AutheliaCtx) SetJSONError(message string) {
b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
if marshalErr != nil {
ctx.Logger.Error(marshalErr)
if replyErr := ctx.ReplyJSON(ErrorResponse{Status: "KO", Message: message}, 0); replyErr != nil {
ctx.Logger.Error(replyErr)
}
ctx.SetContentType(contentTypeApplicationJSON)
ctx.SetBody(b)
}
// ReplyError reply with an error but does not display any stack trace in the logs.
@ -86,24 +81,52 @@ func (ctx *AutheliaCtx) ReplyError(err error, message string) {
ctx.Logger.Error(marshalErr)
}
ctx.SetContentType(contentTypeApplicationJSON)
ctx.SetContentTypeBytes(contentTypeApplicationJSON)
ctx.SetBody(b)
ctx.Logger.Debug(err)
}
// ReplyStatusCode resets a response and replies with the given status code and relevant message.
func (ctx *AutheliaCtx) ReplyStatusCode(statusCode int) {
ctx.Response.Reset()
ctx.SetStatusCode(statusCode)
ctx.SetContentTypeBytes(contentTypeTextPlain)
ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode)))
}
// ReplyJSON writes a JSON response.
func (ctx *AutheliaCtx) ReplyJSON(data interface{}, statusCode int) (err error) {
var (
body []byte
)
if body, err = json.Marshal(data); err != nil {
return fmt.Errorf("unable to marshal JSON body: %w", err)
}
if statusCode > 0 {
ctx.SetStatusCode(statusCode)
}
ctx.SetContentTypeBytes(contentTypeApplicationJSON)
ctx.SetBody(body)
return nil
}
// ReplyUnauthorized response sent when user is unauthorized.
func (ctx *AutheliaCtx) ReplyUnauthorized() {
ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
ctx.ReplyStatusCode(fasthttp.StatusUnauthorized)
}
// ReplyForbidden response sent when access is forbidden to user.
func (ctx *AutheliaCtx) ReplyForbidden() {
ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden)
ctx.ReplyStatusCode(fasthttp.StatusForbidden)
}
// ReplyBadRequest response sent when bad request has been sent.
func (ctx *AutheliaCtx) ReplyBadRequest() {
ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest)
ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
}
// XForwardedProto return the content of the X-Forwarded-Proto header.
@ -208,7 +231,7 @@ func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error {
// ReplyOK is a helper method to reply ok.
func (ctx *AutheliaCtx) ReplyOK() {
ctx.SetContentType(contentTypeApplicationJSON)
ctx.SetContentTypeBytes(contentTypeApplicationJSON)
ctx.SetBody(okMessageBytes)
}
@ -235,15 +258,7 @@ func (ctx *AutheliaCtx) ParseBody(value interface{}) error {
// SetJSONBody Set json body.
func (ctx *AutheliaCtx) SetJSONBody(value interface{}) error {
b, err := json.Marshal(OKResponse{Status: "OK", Data: value})
if err != nil {
return fmt.Errorf("unable to marshal JSON body: %w", err)
}
ctx.SetContentType(contentTypeApplicationJSON)
ctx.SetBody(b)
return nil
return ctx.ReplyJSON(OKResponse{Status: "OK", Data: value}, 0)
}
// RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
@ -329,7 +344,7 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
statusCode = fasthttp.StatusFound
}
ctx.SetContentType(contentTypeTextHTML)
ctx.SetContentTypeBytes(contentTypeTextHTML)
ctx.SetStatusCode(statusCode)
u := fasthttp.AcquireURI()
@ -337,9 +352,9 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
ctx.URI().CopyTo(u)
u.Update(uri)
ctx.Response.Header.SetBytesV("Location", u.FullURI())
ctx.Response.Header.SetBytesKV(headerLocation, u.FullURI())
ctx.SetBodyString(fmt.Sprintf("<a href=\"%s\">%s</a>", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode)))
ctx.SetBodyString(fmt.Sprintf("<a href=\"%s\">%d %s</a>", utils.StringHTMLEscape(string(u.FullURI())), statusCode, fasthttp.StatusMessage(statusCode)))
fasthttp.ReleaseURI(u)
}

View File

@ -9,6 +9,7 @@ import (
var (
headerAccept = []byte(fasthttp.HeaderAccept)
headerContentLength = []byte(fasthttp.HeaderContentLength)
headerLocation = []byte(fasthttp.HeaderLocation)
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
@ -69,12 +70,14 @@ var (
UserValueKeyBaseURL = []byte("base_url")
headerSeparator = []byte(", ")
contentTypeTextPlain = []byte("text/plain; charset=utf-8")
contentTypeTextHTML = []byte("text/html; charset=utf-8")
contentTypeApplicationJSON = []byte("application/json; charset=utf-8")
)
const (
headerValueXRequestedWithXHR = "XMLHttpRequest"
contentTypeApplicationJSON = "application/json"
contentTypeTextHTML = "text/html"
)
var okMessageBytes = []byte("{\"status\":\"OK\"}")

View File

@ -193,7 +193,7 @@ func (s *StandaloneSuite) TestShouldRespectMethodsACL() {
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/")
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">302 Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
req.Header.Set("X-Forwarded-Method", "OPTIONS")
@ -219,7 +219,7 @@ func (s *StandaloneSuite) TestShouldRespondWithCorrectStatusCode() {
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/")
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">302 Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
req.Header.Set("X-Forwarded-Method", "POST")
@ -230,7 +230,7 @@ func (s *StandaloneSuite) TestShouldRespondWithCorrectStatusCode() {
s.Assert().NoError(err)
urlEncodedAdminURL = url.QueryEscape(SecureBaseURL + "/")
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">See Other</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=POST", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">303 See Other</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=POST", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
}
// Standard case using nginx.
@ -247,7 +247,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyUnauthorized() {
s.Assert().Equal(res.StatusCode, 401)
body, err := io.ReadAll(res.Body)
s.Assert().NoError(err)
s.Assert().Equal("Unauthorized", string(body))
s.Assert().Equal("401 Unauthorized", string(body))
}
// Standard case using Kubernetes.
@ -266,7 +266,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalURL() {
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(AdminBaseURL)
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">302 Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
}
func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalHostURI() {
@ -285,7 +285,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalHostURI(
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/")
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
s.Assert().Equal(fmt.Sprintf("<a href=\"%s\">302 Found</a>", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
}
func (s *StandaloneSuite) TestShouldRecordMetrics() {