358 lines
9.5 KiB
Go
358 lines
9.5 KiB
Go
|
package middlewares
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"time"
|
||
|
|
||
|
"github.com/valyala/fasthttp"
|
||
|
|
||
|
"github.com/authelia/authelia/v4/internal/authentication"
|
||
|
"github.com/authelia/authelia/v4/internal/session"
|
||
|
)
|
||
|
|
||
|
type ProtectionBuilder struct {
|
||
|
escalation *OTPEscalationProtectedEndpointConfig
|
||
|
level *RequiredLevelProtectedEndpointConfig
|
||
|
}
|
||
|
|
||
|
type Protection struct {
|
||
|
level authentication.Level
|
||
|
|
||
|
escalationSkip2FA bool
|
||
|
}
|
||
|
|
||
|
func (p *Protection) handler(ctx *AutheliaCtx, userSession *session.UserSession) (level, escalation bool) {
|
||
|
return p.handleLevel(ctx, userSession), p.handleEscalation(ctx, userSession)
|
||
|
}
|
||
|
|
||
|
func (p *Protection) handleLevel(ctx *AutheliaCtx, userSession *session.UserSession) (level bool) {
|
||
|
if p.level == authentication.NotAuthenticated {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
func (p *Protection) handleEscalation(ctx *AutheliaCtx, userSession *session.UserSession) (escalation bool) {
|
||
|
if p.escalationSkip2FA && userSession.AuthenticationLevel >= authentication.TwoFactor {
|
||
|
ctx.Logger.
|
||
|
WithField("username", userSession.Username).
|
||
|
Warning("User elevated session check has skipped due to 2FA")
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
if userSession.Elevations.User == nil {
|
||
|
ctx.Logger.
|
||
|
WithField("username", userSession.Username).
|
||
|
Warning("User session elevation has not been created")
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if userSession.Elevations.User.Expires.Before(ctx.Clock.Now()) {
|
||
|
ctx.Logger.
|
||
|
WithField("username", userSession.Username).
|
||
|
WithField("expires", userSession.Elevations.User.Expires).
|
||
|
Debug("User session elevation has expired")
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if !ctx.RemoteIP().Equal(userSession.Elevations.User.RemoteIP) {
|
||
|
ctx.Logger.
|
||
|
WithField("username", userSession.Username).
|
||
|
WithField("elevation_ip", userSession.Elevations.User.RemoteIP).
|
||
|
Warning("User session elevation IP did not match the request")
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (p *Protection) Handler(ctx *AutheliaCtx) {
|
||
|
userSession, err := ctx.GetSession()
|
||
|
|
||
|
if err != nil || userSession.IsAnonymous() {
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
level, escalation := p.handler(ctx, &userSession)
|
||
|
|
||
|
}
|
||
|
|
||
|
func (p *Protection) Middleware(next RequestHandler) RequestHandler {
|
||
|
return func(ctx *AutheliaCtx) {
|
||
|
userSession, err := ctx.GetSession()
|
||
|
|
||
|
if err != nil || userSession.IsAnonymous() {
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
level, escalation := p.handler(ctx, &userSession)
|
||
|
|
||
|
if level && escalation {
|
||
|
next(ctx)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
type ProtectionEscalation struct {
|
||
|
}
|
||
|
|
||
|
// OTPEscalationProtectedEndpointConfig represents how the Escalation middleware behaves.
|
||
|
type OTPEscalationProtectedEndpointConfig struct {
|
||
|
Characters int
|
||
|
EmailValidityDuration time.Duration
|
||
|
EscalationValidityDuration time.Duration
|
||
|
Skip2FA bool
|
||
|
}
|
||
|
|
||
|
type RequiredLevelProtectedEndpointConfig struct {
|
||
|
Level authentication.Level
|
||
|
}
|
||
|
|
||
|
type ProtectedEndpointConfig struct {
|
||
|
OTPEscalation *OTPEscalationProtectedEndpointConfig
|
||
|
RequiredLevel *RequiredLevelProtectedEndpointConfig
|
||
|
}
|
||
|
|
||
|
func NewProtectedEndpoint(config *ProtectedEndpointConfig) AutheliaMiddleware {
|
||
|
return ProtectedEndpoint(NewProtectedEndpointHandlers(config)...)
|
||
|
}
|
||
|
|
||
|
func NewProtectedEndpointHandlers(config *ProtectedEndpointConfig) (handlers []ProtectedEndpointHandler) {
|
||
|
if config.RequiredLevel != nil {
|
||
|
handlers = append(handlers, &RequiredLevelProtectedEndpointHandler{level: config.RequiredLevel.Level})
|
||
|
}
|
||
|
|
||
|
if config.OTPEscalation != nil {
|
||
|
handlers = append(handlers, &OTPEscalationProtectedEndpointHandler{config: config.OTPEscalation})
|
||
|
}
|
||
|
|
||
|
return handlers
|
||
|
}
|
||
|
|
||
|
func ProtectedEndpoint(handlers ...ProtectedEndpointHandler) AutheliaMiddleware {
|
||
|
n := len(handlers)
|
||
|
|
||
|
return func(next RequestHandler) RequestHandler {
|
||
|
return func(ctx *AutheliaCtx) {
|
||
|
s, err := ctx.GetSession()
|
||
|
|
||
|
if err != nil || s.IsAnonymous() {
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
failed, failedAuthentication, failedElevation := doCheckProtectionHandlers(ctx, &s, n, handlers)
|
||
|
|
||
|
if failed {
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusForbidden, fasthttp.StatusMessage(fasthttp.StatusForbidden), failedAuthentication, failedElevation)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
next(ctx)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func ProtectedEndpointStatus(handlers ...ProtectedEndpointHandler) RequestHandler {
|
||
|
n := len(handlers)
|
||
|
|
||
|
return func(ctx *AutheliaCtx) {
|
||
|
s, err := ctx.GetSession()
|
||
|
|
||
|
if err != nil || s.IsAnonymous() {
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusUnauthorized, fasthttp.StatusMessage(fasthttp.StatusUnauthorized), false, false)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
_, failedAuthentication, failedElevation := doCheckProtectionHandlers(ctx, &s, n, handlers)
|
||
|
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusOK, "", failedAuthentication, failedElevation)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func doCheckProtectionHandlers(ctx *AutheliaCtx, s *session.UserSession, n int, handlers []ProtectedEndpointHandler) (failed, authentication, elevation bool) {
|
||
|
for i := 0; i < n; i++ {
|
||
|
if handlers[i].Check(ctx, s) {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
failed = true
|
||
|
|
||
|
if handlers[i].IsAuthentication() {
|
||
|
authentication = true
|
||
|
}
|
||
|
|
||
|
if handlers[i].IsElevation() {
|
||
|
elevation = true
|
||
|
}
|
||
|
|
||
|
handlers[i].Failure(ctx, s)
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
type ProtectedEndpointHandler interface {
|
||
|
Name() string
|
||
|
Check(ctx *AutheliaCtx, s *session.UserSession) (success bool)
|
||
|
Failure(ctx *AutheliaCtx, s *session.UserSession)
|
||
|
|
||
|
IsAuthentication() bool
|
||
|
IsElevation() bool
|
||
|
}
|
||
|
|
||
|
func NewRequiredLevelProtectedEndpointHandler(level authentication.Level, statusCode int) *RequiredLevelProtectedEndpointHandler {
|
||
|
handler := &RequiredLevelProtectedEndpointHandler{
|
||
|
level: level,
|
||
|
statusCode: statusCode,
|
||
|
}
|
||
|
|
||
|
if handler.statusCode == 0 {
|
||
|
handler.statusCode = fasthttp.StatusForbidden
|
||
|
}
|
||
|
|
||
|
if handler.level == 0 {
|
||
|
handler.level = authentication.OneFactor
|
||
|
}
|
||
|
|
||
|
return handler
|
||
|
}
|
||
|
|
||
|
type RequiredLevelProtectedEndpointHandler struct {
|
||
|
level authentication.Level
|
||
|
statusCode int
|
||
|
}
|
||
|
|
||
|
func (h *RequiredLevelProtectedEndpointHandler) Name() string {
|
||
|
return fmt.Sprintf("required_level(%s)", h.level)
|
||
|
}
|
||
|
|
||
|
func (h *RequiredLevelProtectedEndpointHandler) IsAuthentication() bool {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (h *RequiredLevelProtectedEndpointHandler) IsElevation() bool {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (h *RequiredLevelProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
|
||
|
return s.AuthenticationLevel >= h.level
|
||
|
}
|
||
|
|
||
|
func (h *RequiredLevelProtectedEndpointHandler) Failure(_ *AutheliaCtx, _ *session.UserSession) {
|
||
|
}
|
||
|
|
||
|
func NewOTPEscalationProtectedEndpointHandler(config OTPEscalationProtectedEndpointConfig) *OTPEscalationProtectedEndpointHandler {
|
||
|
return &OTPEscalationProtectedEndpointHandler{
|
||
|
config: &config,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type OTPEscalationProtectedEndpointHandler struct {
|
||
|
config *OTPEscalationProtectedEndpointConfig
|
||
|
}
|
||
|
|
||
|
func (h *OTPEscalationProtectedEndpointHandler) Name() string {
|
||
|
return "one_time_password"
|
||
|
}
|
||
|
|
||
|
func (h *OTPEscalationProtectedEndpointHandler) IsAuthentication() bool {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (h *OTPEscalationProtectedEndpointHandler) IsElevation() bool {
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (h *OTPEscalationProtectedEndpointHandler) Check(ctx *AutheliaCtx, s *session.UserSession) (success bool) {
|
||
|
if h.config.Skip2FA && s.AuthenticationLevel >= authentication.TwoFactor {
|
||
|
ctx.Logger.
|
||
|
WithField("username", s.Username).
|
||
|
Warning("User elevated session check has skipped due to 2FA")
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
if s.Elevations.User == nil {
|
||
|
ctx.Logger.
|
||
|
WithField("username", s.Username).
|
||
|
Warning("User elevated session has not been created")
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if s.Elevations.User.Expires.Before(ctx.Clock.Now()) {
|
||
|
ctx.Logger.
|
||
|
WithField("username", s.Username).
|
||
|
WithField("expires", s.Elevations.User.Expires).
|
||
|
Debug("User session elevation has expired")
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
if !ctx.RemoteIP().Equal(s.Elevations.User.RemoteIP) {
|
||
|
ctx.Logger.
|
||
|
WithField("username", s.Username).
|
||
|
WithField("elevation_ip", s.Elevations.User.RemoteIP).
|
||
|
Warning("User session elevation IP did not match the request")
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (h *OTPEscalationProtectedEndpointHandler) Failure(ctx *AutheliaCtx, s *session.UserSession) {
|
||
|
if s.Elevations.User != nil {
|
||
|
// If we make it here we should destroy the elevation data.
|
||
|
s.Elevations.User = nil
|
||
|
|
||
|
if err := ctx.SaveSession(*s); err != nil {
|
||
|
ctx.Logger.WithError(err).Error("Error session after user elevated session failure")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Require1FA requires the user to have authenticated with at least one-factor authentication (i.e. password).
|
||
|
func Require1FA(next RequestHandler) RequestHandler {
|
||
|
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.OneFactor, fasthttp.StatusForbidden))
|
||
|
|
||
|
return handler(next)
|
||
|
}
|
||
|
|
||
|
// Require2FA requires the user to have authenticated with two-factor authentication.
|
||
|
func Require2FA(next RequestHandler) RequestHandler {
|
||
|
handler := ProtectedEndpoint(NewRequiredLevelProtectedEndpointHandler(authentication.TwoFactor, fasthttp.StatusForbidden))
|
||
|
|
||
|
return handler(next)
|
||
|
}
|
||
|
|
||
|
// Require2FAWithAPIResponse requires the user to have authenticated with two-factor authentication.
|
||
|
func Require2FAWithAPIResponse(next RequestHandler) RequestHandler {
|
||
|
return func(ctx *AutheliaCtx) {
|
||
|
s, err := ctx.GetSession()
|
||
|
|
||
|
if err != nil || s.AuthenticationLevel < authentication.TwoFactor {
|
||
|
ctx.SetAuthenticationResponseJSON(fasthttp.StatusForbidden, "Authentication Required.", true, false)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
next(ctx)
|
||
|
}
|
||
|
}
|