fix(authentication): utilize msad password history control (#3256)

This fixes an issue where the Microsoft Active Directory Server Policy Hints control was not being used to prevent avoidance of the PSO / FGPP applicable to the user.
pull/3336/head
James Elliott 2022-05-10 14:38:36 +10:00 committed by GitHub
parent 3178e88c58
commit 150e54c3ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 3318 additions and 391 deletions

View File

@ -18,7 +18,46 @@ const (
const ( const (
ldapSupportedExtensionAttribute = "supportedExtension" ldapSupportedExtensionAttribute = "supportedExtension"
ldapOIDPasswdModifyExtension = "1.3.6.1.4.1.4203.1.11.1" // http://oidref.com/1.3.6.1.4.1.4203.1.11.1
// LDAP Extension OID: Password Modify Extended Operation.
//
// RFC3062: https://datatracker.ietf.org/doc/html/rfc3062
//
// OID Reference: http://oidref.com/1.3.6.1.4.1.4203.1.11.1
//
// See the linked documents for more information.
ldapOIDExtensionPwdModifyExOp = "1.3.6.1.4.1.4203.1.11.1"
// LDAP Extension OID: Transport Layer Security.
//
// RFC2830: https://datatracker.ietf.org/doc/html/rfc2830
//
// OID Reference: https://oidref.com/1.3.6.1.4.1.1466.20037
//
// See the linked documents for more information.
ldapOIDExtensionTLS = "1.3.6.1.4.1.1466.20037"
)
const (
ldapSupportedControlAttribute = "supportedControl"
// LDAP Control OID: Microsoft Password Policy Hints.
//
// MS ADTS: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/4add7bce-e502-4e0f-9d69-1a3f153713e2
//
// OID Reference: https://oidref.com/1.2.840.113556.1.4.2239
//
// See the linked documents for more information.
ldapOIDControlMsftServerPolicyHints = "1.2.840.113556.1.4.2239"
// LDAP Control OID: Microsoft Password Policy Hints (deprecated).
//
// MS ADTS: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/49751d58-8115-4277-8faf-64c83a5f658f
//
// OID Reference: https://oidref.com/1.2.840.113556.1.4.2066
//
// See the linked documents for more information.
ldapOIDControlMsftServerPolicyHintsDeprecated = "1.2.840.113556.1.4.2066"
) )
const ( const (
@ -32,6 +71,10 @@ const (
ldapPlaceholderUsername = "{username}" ldapPlaceholderUsername = "{username}"
) )
const (
none = "none"
)
// CryptAlgo the crypt representation of an algorithm used in the prefix of the hash. // CryptAlgo the crypt representation of an algorithm used in the prefix of the hash.
type CryptAlgo string type CryptAlgo string

View File

@ -3,5 +3,5 @@ package authentication
// This file is used to generate mocks. You can generate all mocks using the // This file is used to generate mocks. You can generate all mocks using the
// command `go generate github.com/authelia/authelia/v4/internal/authentication`. // command `go generate github.com/authelia/authelia/v4/internal/authentication`.
//go:generate mockgen -package authentication -destination ldap_connection_mock.go -mock_names LDAPConnection=MockLDAPConnection github.com/authelia/authelia/v4/internal/authentication LDAPConnection //go:generate mockgen -package authentication -destination ldap_client_mock.go -mock_names LDAPClient=MockLDAPClient github.com/authelia/authelia/v4/internal/authentication LDAPClient
//go:generate mockgen -package authentication -destination ldap_connection_factory_mock.go -mock_names LDAPConnectionFactory=MockLDAPConnectionFactory github.com/authelia/authelia/v4/internal/authentication LDAPConnectionFactory //go:generate mockgen -package authentication -destination ldap_client_factory_mock.go -mock_names LDAPClientFactory=MockLDAPClientFactory github.com/authelia/authelia/v4/internal/authentication LDAPClientFactory

View File

@ -0,0 +1,18 @@
package authentication
import (
"github.com/go-ldap/ldap/v3"
)
// ProductionLDAPClientFactory the production implementation of an ldap connection factory.
type ProductionLDAPClientFactory struct{}
// NewProductionLDAPClientFactory create a concrete ldap connection factory.
func NewProductionLDAPClientFactory() *ProductionLDAPClientFactory {
return &ProductionLDAPClientFactory{}
}
// DialURL creates a client from an LDAP URL when successful.
func (f *ProductionLDAPClientFactory) DialURL(addr string, opts ...ldap.DialOpt) (client LDAPClient, err error) {
return ldap.DialURL(addr, opts...)
}

View File

@ -0,0 +1,55 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/authelia/authelia/v4/internal/authentication (interfaces: LDAPClientFactory)
// Package authentication is a generated GoMock package.
package authentication
import (
reflect "reflect"
v3 "github.com/go-ldap/ldap/v3"
gomock "github.com/golang/mock/gomock"
)
// MockLDAPClientFactory is a mock of LDAPClientFactory interface.
type MockLDAPClientFactory struct {
ctrl *gomock.Controller
recorder *MockLDAPClientFactoryMockRecorder
}
// MockLDAPClientFactoryMockRecorder is the mock recorder for MockLDAPClientFactory.
type MockLDAPClientFactoryMockRecorder struct {
mock *MockLDAPClientFactory
}
// NewMockLDAPClientFactory creates a new mock instance.
func NewMockLDAPClientFactory(ctrl *gomock.Controller) *MockLDAPClientFactory {
mock := &MockLDAPClientFactory{ctrl: ctrl}
mock.recorder = &MockLDAPClientFactoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLDAPClientFactory) EXPECT() *MockLDAPClientFactoryMockRecorder {
return m.recorder
}
// DialURL mocks base method.
func (m *MockLDAPClientFactory) DialURL(arg0 string, arg1 ...v3.DialOpt) (LDAPClient, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "DialURL", varargs...)
ret0, _ := ret[0].(LDAPClient)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialURL indicates an expected call of DialURL.
func (mr *MockLDAPClientFactoryMockRecorder) DialURL(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialURL", reflect.TypeOf((*MockLDAPClientFactory)(nil).DialURL), varargs...)
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/authelia/authelia/v4/internal/authentication (interfaces: LDAPConnection) // Source: github.com/authelia/authelia/v4/internal/authentication (interfaces: LDAPClient)
// Package authentication is a generated GoMock package. // Package authentication is a generated GoMock package.
package authentication package authentication
@ -12,31 +12,31 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockLDAPConnection is a mock of LDAPConnection interface. // MockLDAPClient is a mock of LDAPClient interface.
type MockLDAPConnection struct { type MockLDAPClient struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockLDAPConnectionMockRecorder recorder *MockLDAPClientMockRecorder
} }
// MockLDAPConnectionMockRecorder is the mock recorder for MockLDAPConnection. // MockLDAPClientMockRecorder is the mock recorder for MockLDAPClient.
type MockLDAPConnectionMockRecorder struct { type MockLDAPClientMockRecorder struct {
mock *MockLDAPConnection mock *MockLDAPClient
} }
// NewMockLDAPConnection creates a new mock instance. // NewMockLDAPClient creates a new mock instance.
func NewMockLDAPConnection(ctrl *gomock.Controller) *MockLDAPConnection { func NewMockLDAPClient(ctrl *gomock.Controller) *MockLDAPClient {
mock := &MockLDAPConnection{ctrl: ctrl} mock := &MockLDAPClient{ctrl: ctrl}
mock.recorder = &MockLDAPConnectionMockRecorder{mock} mock.recorder = &MockLDAPClientMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLDAPConnection) EXPECT() *MockLDAPConnectionMockRecorder { func (m *MockLDAPClient) EXPECT() *MockLDAPClientMockRecorder {
return m.recorder return m.recorder
} }
// Bind mocks base method. // Bind mocks base method.
func (m *MockLDAPConnection) Bind(arg0, arg1 string) error { func (m *MockLDAPClient) Bind(arg0, arg1 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Bind", arg0, arg1) ret := m.ctrl.Call(m, "Bind", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -44,25 +44,25 @@ func (m *MockLDAPConnection) Bind(arg0, arg1 string) error {
} }
// Bind indicates an expected call of Bind. // Bind indicates an expected call of Bind.
func (mr *MockLDAPConnectionMockRecorder) Bind(arg0, arg1 interface{}) *gomock.Call { func (mr *MockLDAPClientMockRecorder) Bind(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockLDAPConnection)(nil).Bind), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockLDAPClient)(nil).Bind), arg0, arg1)
} }
// Close mocks base method. // Close mocks base method.
func (m *MockLDAPConnection) Close() { func (m *MockLDAPClient) Close() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Close") m.ctrl.Call(m, "Close")
} }
// Close indicates an expected call of Close. // Close indicates an expected call of Close.
func (mr *MockLDAPConnectionMockRecorder) Close() *gomock.Call { func (mr *MockLDAPClientMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLDAPConnection)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLDAPClient)(nil).Close))
} }
// Modify mocks base method. // Modify mocks base method.
func (m *MockLDAPConnection) Modify(arg0 *ldap.ModifyRequest) error { func (m *MockLDAPClient) Modify(arg0 *ldap.ModifyRequest) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Modify", arg0) ret := m.ctrl.Call(m, "Modify", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -70,13 +70,13 @@ func (m *MockLDAPConnection) Modify(arg0 *ldap.ModifyRequest) error {
} }
// Modify indicates an expected call of Modify. // Modify indicates an expected call of Modify.
func (mr *MockLDAPConnectionMockRecorder) Modify(arg0 interface{}) *gomock.Call { func (mr *MockLDAPClientMockRecorder) Modify(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Modify", reflect.TypeOf((*MockLDAPConnection)(nil).Modify), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Modify", reflect.TypeOf((*MockLDAPClient)(nil).Modify), arg0)
} }
// PasswordModify mocks base method. // PasswordModify mocks base method.
func (m *MockLDAPConnection) PasswordModify(arg0 *ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) { func (m *MockLDAPClient) PasswordModify(arg0 *ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PasswordModify", arg0) ret := m.ctrl.Call(m, "PasswordModify", arg0)
ret0, _ := ret[0].(*ldap.PasswordModifyResult) ret0, _ := ret[0].(*ldap.PasswordModifyResult)
@ -85,13 +85,13 @@ func (m *MockLDAPConnection) PasswordModify(arg0 *ldap.PasswordModifyRequest) (*
} }
// PasswordModify indicates an expected call of PasswordModify. // PasswordModify indicates an expected call of PasswordModify.
func (mr *MockLDAPConnectionMockRecorder) PasswordModify(arg0 interface{}) *gomock.Call { func (mr *MockLDAPClientMockRecorder) PasswordModify(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordModify", reflect.TypeOf((*MockLDAPConnection)(nil).PasswordModify), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordModify", reflect.TypeOf((*MockLDAPClient)(nil).PasswordModify), arg0)
} }
// Search mocks base method. // Search mocks base method.
func (m *MockLDAPConnection) Search(arg0 *ldap.SearchRequest) (*ldap.SearchResult, error) { func (m *MockLDAPClient) Search(arg0 *ldap.SearchRequest) (*ldap.SearchResult, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Search", arg0) ret := m.ctrl.Call(m, "Search", arg0)
ret0, _ := ret[0].(*ldap.SearchResult) ret0, _ := ret[0].(*ldap.SearchResult)
@ -100,13 +100,13 @@ func (m *MockLDAPConnection) Search(arg0 *ldap.SearchRequest) (*ldap.SearchResul
} }
// Search indicates an expected call of Search. // Search indicates an expected call of Search.
func (mr *MockLDAPConnectionMockRecorder) Search(arg0 interface{}) *gomock.Call { func (mr *MockLDAPClientMockRecorder) Search(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockLDAPConnection)(nil).Search), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockLDAPClient)(nil).Search), arg0)
} }
// StartTLS mocks base method. // StartTLS mocks base method.
func (m *MockLDAPConnection) StartTLS(arg0 *tls.Config) error { func (m *MockLDAPClient) StartTLS(arg0 *tls.Config) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartTLS", arg0) ret := m.ctrl.Call(m, "StartTLS", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -114,7 +114,7 @@ func (m *MockLDAPConnection) StartTLS(arg0 *tls.Config) error {
} }
// StartTLS indicates an expected call of StartTLS. // StartTLS indicates an expected call of StartTLS.
func (mr *MockLDAPConnectionMockRecorder) StartTLS(arg0 interface{}) *gomock.Call { func (mr *MockLDAPClientMockRecorder) StartTLS(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartTLS", reflect.TypeOf((*MockLDAPConnection)(nil).StartTLS), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartTLS", reflect.TypeOf((*MockLDAPClient)(nil).StartTLS), arg0)
} }

View File

@ -1,18 +0,0 @@
package authentication
import (
"github.com/go-ldap/ldap/v3"
)
// ProductionLDAPConnectionFactory the production implementation of an ldap connection factory.
type ProductionLDAPConnectionFactory struct{}
// NewProductionLDAPConnectionFactory create a concrete ldap connection factory.
func NewProductionLDAPConnectionFactory() *ProductionLDAPConnectionFactory {
return &ProductionLDAPConnectionFactory{}
}
// DialURL creates a connection from an LDAP URL when successful.
func (f *ProductionLDAPConnectionFactory) DialURL(addr string, opts ...ldap.DialOpt) (conn LDAPConnection, err error) {
return ldap.DialURL(addr, opts...)
}

View File

@ -1,55 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/authelia/authelia/v4/internal/authentication (interfaces: LDAPConnectionFactory)
// Package authentication is a generated GoMock package.
package authentication
import (
reflect "reflect"
v3 "github.com/go-ldap/ldap/v3"
gomock "github.com/golang/mock/gomock"
)
// MockLDAPConnectionFactory is a mock of LDAPConnectionFactory interface.
type MockLDAPConnectionFactory struct {
ctrl *gomock.Controller
recorder *MockLDAPConnectionFactoryMockRecorder
}
// MockLDAPConnectionFactoryMockRecorder is the mock recorder for MockLDAPConnectionFactory.
type MockLDAPConnectionFactoryMockRecorder struct {
mock *MockLDAPConnectionFactory
}
// NewMockLDAPConnectionFactory creates a new mock instance.
func NewMockLDAPConnectionFactory(ctrl *gomock.Controller) *MockLDAPConnectionFactory {
mock := &MockLDAPConnectionFactory{ctrl: ctrl}
mock.recorder = &MockLDAPConnectionFactoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLDAPConnectionFactory) EXPECT() *MockLDAPConnectionFactoryMockRecorder {
return m.recorder
}
// DialURL mocks base method.
func (m *MockLDAPConnectionFactory) DialURL(arg0 string, arg1 ...v3.DialOpt) (LDAPConnection, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "DialURL", varargs...)
ret0, _ := ret[0].(LDAPConnection)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialURL indicates an expected call of DialURL.
func (mr *MockLDAPConnectionFactoryMockRecorder) DialURL(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialURL", reflect.TypeOf((*MockLDAPConnectionFactory)(nil).DialURL), varargs...)
}

View File

@ -0,0 +1,36 @@
package authentication
import (
ber "github.com/go-asn1-ber/asn1-ber"
)
type controlMsftServerPolicyHints struct {
oid string
}
// GetControlType implements ldap.Control.
func (c *controlMsftServerPolicyHints) GetControlType() string {
return c.oid
}
// Encode implements ldap.Control.
func (c *controlMsftServerPolicyHints) Encode() (packet *ber.Packet) {
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PolicyHintsRequestValue")
seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 1, "Flags"))
controlValue := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Policy Hints)")
controlValue.AppendChild(seq)
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.GetControlType(), "Control Type (LDAP_SERVER_POLICY_HINTS_OID)"))
packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, true, "Criticality"))
packet.AppendChild(controlValue)
return packet
}
// String implements ldap.Control.
func (c *controlMsftServerPolicyHints) String() string {
return "Enforce the password history length constraint (MS-SAMR section 3.1.1.7.1) during password set: " + c.GetControlType()
}

View File

@ -21,12 +21,12 @@ type LDAPUserProvider struct {
tlsConfig *tls.Config tlsConfig *tls.Config
dialOpts []ldap.DialOpt dialOpts []ldap.DialOpt
log *logrus.Logger log *logrus.Logger
factory LDAPConnectionFactory factory LDAPClientFactory
disableResetPassword bool disableResetPassword bool
// Automatically detected ldap features. // Automatically detected LDAP features.
supportExtensionPasswdModify bool features LDAPSupportedFeatures
// Dynamically generated users values. // Dynamically generated users values.
usersBaseDN string usersBaseDN string
@ -48,7 +48,7 @@ func NewLDAPUserProvider(config schema.AuthenticationBackendConfiguration, certP
return provider return provider
} }
func newLDAPUserProvider(config schema.LDAPAuthenticationBackendConfiguration, disableResetPassword bool, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) { func newLDAPUserProvider(config schema.LDAPAuthenticationBackendConfiguration, disableResetPassword bool, certPool *x509.CertPool, factory LDAPClientFactory) (provider *LDAPUserProvider) {
if config.TLS == nil { if config.TLS == nil {
config.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS config.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS
} }
@ -64,7 +64,7 @@ func newLDAPUserProvider(config schema.LDAPAuthenticationBackendConfiguration, d
} }
if factory == nil { if factory == nil {
factory = NewProductionLDAPConnectionFactory() factory = NewProductionLDAPClientFactory()
} }
provider = &LDAPUserProvider{ provider = &LDAPUserProvider{
@ -83,27 +83,27 @@ func newLDAPUserProvider(config schema.LDAPAuthenticationBackendConfiguration, d
} }
// CheckUserPassword checks if provided password matches for the given user. // CheckUserPassword checks if provided password matches for the given user.
func (p *LDAPUserProvider) CheckUserPassword(inputUsername string, password string) (valid bool, err error) { func (p *LDAPUserProvider) CheckUserPassword(username string, password string) (valid bool, err error) {
var ( var (
conn, connUser LDAPConnection client, clientUser LDAPClient
profile *ldapUserProfile profile *ldapUserProfile
) )
if conn, err = p.connect(); err != nil { if client, err = p.connect(); err != nil {
return false, err return false, err
} }
defer conn.Close() defer client.Close()
if profile, err = p.getUserProfile(conn, inputUsername); err != nil { if profile, err = p.getUserProfile(client, username); err != nil {
return false, err return false, err
} }
if connUser, err = p.connectCustom(p.config.URL, profile.DN, password, p.config.StartTLS, p.dialOpts...); err != nil { if clientUser, err = p.connectCustom(p.config.URL, profile.DN, password, p.config.StartTLS, p.dialOpts...); err != nil {
return false, fmt.Errorf("authentication failed. Cause: %w", err) return false, fmt.Errorf("authentication failed. Cause: %w", err)
} }
defer connUser.Close() defer clientUser.Close()
return true, nil return true, nil
} }
@ -111,17 +111,17 @@ func (p *LDAPUserProvider) CheckUserPassword(inputUsername string, password stri
// GetDetails retrieve the groups a user belongs to. // GetDetails retrieve the groups a user belongs to.
func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, err error) { func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, err error) {
var ( var (
conn LDAPConnection client LDAPClient
profile *ldapUserProfile profile *ldapUserProfile
) )
if conn, err = p.connect(); err != nil { if client, err = p.connect(); err != nil {
return nil, err return nil, err
} }
defer conn.Close() defer client.Close()
if profile, err = p.getUserProfile(conn, username); err != nil { if profile, err = p.getUserProfile(client, username); err != nil {
return nil, err return nil, err
} }
@ -141,7 +141,7 @@ func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, er
0, 0, false, filter, p.groupsAttributes, nil, 0, 0, false, filter, p.groupsAttributes, nil,
) )
if searchResult, err = p.search(conn, searchRequest); err != nil { if searchResult, err = p.search(client, searchRequest); err != nil {
return nil, fmt.Errorf("unable to retrieve groups of user '%s'. Cause: %w", username, err) return nil, fmt.Errorf("unable to retrieve groups of user '%s'. Cause: %w", username, err)
} }
@ -168,31 +168,38 @@ func (p *LDAPUserProvider) GetDetails(username string) (details *UserDetails, er
// UpdatePassword update the password of the given user. // UpdatePassword update the password of the given user.
func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error) { func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error) {
var ( var (
conn LDAPConnection client LDAPClient
profile *ldapUserProfile profile *ldapUserProfile
) )
if conn, err = p.connect(); err != nil { if client, err = p.connect(); err != nil {
return fmt.Errorf("unable to update password. Cause: %w", err) return fmt.Errorf("unable to update password. Cause: %w", err)
} }
defer conn.Close() defer client.Close()
if profile, err = p.getUserProfile(conn, username); err != nil { if profile, err = p.getUserProfile(client, username); err != nil {
return fmt.Errorf("unable to update password. Cause: %w", err) return fmt.Errorf("unable to update password. Cause: %w", err)
} }
var controls []ldap.Control var controls []ldap.Control
switch { switch {
case p.supportExtensionPasswdModify: case p.features.ControlTypes.MsftPwdPolHints:
controls = append(controls, &controlMsftServerPolicyHints{ldapOIDControlMsftServerPolicyHints})
case p.features.ControlTypes.MsftPwdPolHintsDeprecated:
controls = append(controls, &controlMsftServerPolicyHints{ldapOIDControlMsftServerPolicyHintsDeprecated})
}
switch {
case p.features.Extensions.PwdModifyExOp:
pwdModifyRequest := ldap.NewPasswordModifyRequest( pwdModifyRequest := ldap.NewPasswordModifyRequest(
profile.DN, profile.DN,
"", "",
password, password,
) )
err = p.pwdModify(conn, pwdModifyRequest) err = p.pwdModify(client, pwdModifyRequest)
case p.config.Implementation == schema.LDAPImplementationActiveDirectory: case p.config.Implementation == schema.LDAPImplementationActiveDirectory:
modifyRequest := ldap.NewModifyRequest(profile.DN, controls) modifyRequest := ldap.NewModifyRequest(profile.DN, controls)
// The password needs to be enclosed in quotes // The password needs to be enclosed in quotes
@ -200,12 +207,12 @@ func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error)
pwdEncoded, _ := utf16LittleEndian.NewEncoder().String(fmt.Sprintf("\"%s\"", password)) pwdEncoded, _ := utf16LittleEndian.NewEncoder().String(fmt.Sprintf("\"%s\"", password))
modifyRequest.Replace(ldapAttributeUnicodePwd, []string{pwdEncoded}) modifyRequest.Replace(ldapAttributeUnicodePwd, []string{pwdEncoded})
err = p.modify(conn, modifyRequest) err = p.modify(client, modifyRequest)
default: default:
modifyRequest := ldap.NewModifyRequest(profile.DN, controls) modifyRequest := ldap.NewModifyRequest(profile.DN, controls)
modifyRequest.Replace(ldapAttributeUserPassword, []string{password}) modifyRequest.Replace(ldapAttributeUserPassword, []string{password})
err = p.modify(conn, modifyRequest) err = p.modify(client, modifyRequest)
} }
if err != nil { if err != nil {
@ -215,73 +222,74 @@ func (p *LDAPUserProvider) UpdatePassword(username, password string) (err error)
return nil return nil
} }
func (p *LDAPUserProvider) connect() (LDAPConnection, error) { func (p *LDAPUserProvider) connect() (client LDAPClient, err error) {
return p.connectCustom(p.config.URL, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...) return p.connectCustom(p.config.URL, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...)
} }
func (p *LDAPUserProvider) connectCustom(url, userDN, password string, startTLS bool, opts ...ldap.DialOpt) (conn LDAPConnection, err error) { func (p *LDAPUserProvider) connectCustom(url, userDN, password string, startTLS bool, opts ...ldap.DialOpt) (client LDAPClient, err error) {
if conn, err = p.factory.DialURL(url, opts...); err != nil { if client, err = p.factory.DialURL(url, opts...); err != nil {
return nil, fmt.Errorf("dial failed with error: %w", err) return nil, fmt.Errorf("dial failed with error: %w", err)
} }
if startTLS { if startTLS {
if err = conn.StartTLS(p.tlsConfig); err != nil { if err = client.StartTLS(p.tlsConfig); err != nil {
client.Close()
return nil, fmt.Errorf("starttls failed with error: %w", err) return nil, fmt.Errorf("starttls failed with error: %w", err)
} }
} }
if err = conn.Bind(userDN, password); err != nil { if err = client.Bind(userDN, password); err != nil {
client.Close()
return nil, fmt.Errorf("bind failed with error: %w", err) return nil, fmt.Errorf("bind failed with error: %w", err)
} }
return conn, nil return client, nil
} }
func (p *LDAPUserProvider) search(conn LDAPConnection, searchRequest *ldap.SearchRequest) (searchResult *ldap.SearchResult, err error) { func (p *LDAPUserProvider) search(client LDAPClient, searchRequest *ldap.SearchRequest) (searchResult *ldap.SearchResult, err error) {
searchResult, err = conn.Search(searchRequest) if searchResult, err = client.Search(searchRequest); err != nil {
if err != nil {
if referral, ok := p.getReferral(err); ok { if referral, ok := p.getReferral(err); ok {
if errReferral := p.searchReferral(referral, searchRequest, searchResult); errReferral != nil { if searchResult == nil {
return nil, err searchResult = &ldap.SearchResult{
Referrals: []string{referral},
}
} else {
searchResult.Referrals = append(searchResult.Referrals, referral)
} }
return searchResult, nil
} }
return nil, err
} }
if !p.config.PermitReferrals || len(searchResult.Referrals) == 0 { if !p.config.PermitReferrals || len(searchResult.Referrals) == 0 {
if err != nil {
return nil, err
}
return searchResult, nil return searchResult, nil
} }
p.searchReferrals(searchRequest, searchResult) if err = p.searchReferrals(searchRequest, searchResult); err != nil {
return nil, err
}
return searchResult, nil return searchResult, nil
} }
func (p *LDAPUserProvider) searchReferral(referral string, searchRequest *ldap.SearchRequest, searchResult *ldap.SearchResult) (err error) { func (p *LDAPUserProvider) searchReferral(referral string, searchRequest *ldap.SearchRequest, searchResult *ldap.SearchResult) (err error) {
var ( var (
conn LDAPConnection client LDAPClient
result *ldap.SearchResult result *ldap.SearchResult
) )
if conn, err = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); err != nil { if client, err = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); err != nil {
p.log.Errorf("Failed to connect during referred search request (referred to %s): %v", referral, err) return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %w", referral, err)
return err
} }
defer conn.Close() defer client.Close()
if result, err = conn.Search(searchRequest); err != nil { if result, err = client.Search(searchRequest); err != nil {
p.log.Errorf("Failed to perform search operation during referred search request (referred to %s): %v", referral, err) return fmt.Errorf("error occurred performing search on referred LDAP server '%s': %w", referral, err)
return err
}
if len(result.Entries) == 0 {
return err
} }
for i := 0; i < len(result.Entries); i++ { for i := 0; i < len(result.Entries); i++ {
@ -293,14 +301,18 @@ func (p *LDAPUserProvider) searchReferral(referral string, searchRequest *ldap.S
return nil return nil
} }
func (p *LDAPUserProvider) searchReferrals(searchRequest *ldap.SearchRequest, searchResult *ldap.SearchResult) { func (p *LDAPUserProvider) searchReferrals(searchRequest *ldap.SearchRequest, searchResult *ldap.SearchResult) (err error) {
for i := 0; i < len(searchResult.Referrals); i++ { for i := 0; i < len(searchResult.Referrals); i++ {
_ = p.searchReferral(searchResult.Referrals[i], searchRequest, searchResult) if err = p.searchReferral(searchResult.Referrals[i], searchRequest, searchResult); err != nil {
return err
} }
}
return nil
} }
func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername string) (profile *ldapUserProfile, err error) { func (p *LDAPUserProvider) getUserProfile(client LDAPClient, username string) (profile *ldapUserProfile, err error) {
userFilter := p.resolveUsersFilter(inputUsername) userFilter := p.resolveUsersFilter(username)
// Search for the given username. // Search for the given username.
searchRequest := ldap.NewSearchRequest( searchRequest := ldap.NewSearchRequest(
@ -310,8 +322,8 @@ func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername str
var searchResult *ldap.SearchResult var searchResult *ldap.SearchResult
if searchResult, err = p.search(conn, searchRequest); err != nil { if searchResult, err = p.search(client, searchRequest); err != nil {
return nil, fmt.Errorf("cannot find user DN of user '%s'. Cause: %w", inputUsername, err) return nil, fmt.Errorf("cannot find user DN of user '%s'. Cause: %w", username, err)
} }
if len(searchResult.Entries) == 0 { if len(searchResult.Entries) == 0 {
@ -319,7 +331,7 @@ func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername str
} }
if len(searchResult.Entries) > 1 { if len(searchResult.Entries) > 1 {
return nil, fmt.Errorf("multiple users %s found", inputUsername) return nil, fmt.Errorf("there were %d users found when searching for '%s' but there should only be 1", len(searchResult.Entries), username)
} }
userProfile := ldapUserProfile{ userProfile := ldapUserProfile{
@ -327,37 +339,45 @@ func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername str
} }
for _, attr := range searchResult.Entries[0].Attributes { for _, attr := range searchResult.Entries[0].Attributes {
if attr.Name == p.config.DisplayNameAttribute { switch attr.Name {
case p.config.DisplayNameAttribute:
userProfile.DisplayName = attr.Values[0] userProfile.DisplayName = attr.Values[0]
} case p.config.MailAttribute:
if attr.Name == p.config.MailAttribute {
userProfile.Emails = attr.Values userProfile.Emails = attr.Values
} case p.config.UsernameAttribute:
attrs := len(attr.Values)
if attr.Name == p.config.UsernameAttribute {
if len(attr.Values) != 1 {
return nil, fmt.Errorf("user '%s' cannot have multiple value for attribute '%s'",
inputUsername, p.config.UsernameAttribute)
}
switch attrs {
case 1:
userProfile.Username = attr.Values[0] userProfile.Username = attr.Values[0]
case 0:
return nil, fmt.Errorf("user '%s' must have value for attribute '%s'",
username, p.config.UsernameAttribute)
default:
return nil, fmt.Errorf("user '%s' has %d values for for attribute '%s' but the attribute must be a single value attribute",
username, attrs, p.config.UsernameAttribute)
} }
} }
}
if userProfile.Username == "" {
return nil, fmt.Errorf("user '%s' must have value for attribute '%s'",
username, p.config.UsernameAttribute)
}
if userProfile.DN == "" { if userProfile.DN == "" {
return nil, fmt.Errorf("no DN has been found for user %s", inputUsername) return nil, fmt.Errorf("user '%s' must have a distinguished name but the result returned an empty distinguished name", username)
} }
return &userProfile, nil return &userProfile, nil
} }
func (p *LDAPUserProvider) resolveUsersFilter(inputUsername string) (filter string) { func (p *LDAPUserProvider) resolveUsersFilter(username string) (filter string) {
filter = p.config.UsersFilter filter = p.config.UsersFilter
if p.usersFilterReplacementInput { if p.usersFilterReplacementInput {
// The {input} placeholder is replaced by the username input. // The {input} placeholder is replaced by the username input.
filter = strings.ReplaceAll(filter, ldapPlaceholderInput, ldapEscape(inputUsername)) filter = strings.ReplaceAll(filter, ldapPlaceholderInput, ldapEscape(username))
} }
p.log.Tracef("Detected user filter is %s", filter) p.log.Tracef("Detected user filter is %s", filter)
@ -365,12 +385,12 @@ func (p *LDAPUserProvider) resolveUsersFilter(inputUsername string) (filter stri
return filter return filter
} }
func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ldapUserProfile) (filter string, err error) { //nolint:unparam func (p *LDAPUserProvider) resolveGroupsFilter(username string, profile *ldapUserProfile) (filter string, err error) { //nolint:unparam
filter = p.config.GroupsFilter filter = p.config.GroupsFilter
if p.groupsFilterReplacementInput { if p.groupsFilterReplacementInput {
// The {input} placeholder is replaced by the users username input. // The {input} placeholder is replaced by the users username input.
filter = strings.ReplaceAll(p.config.GroupsFilter, ldapPlaceholderInput, ldapEscape(inputUsername)) filter = strings.ReplaceAll(p.config.GroupsFilter, ldapPlaceholderInput, ldapEscape(username))
} }
if profile != nil { if profile != nil {
@ -388,8 +408,8 @@ func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ld
return filter, nil return filter, nil
} }
func (p *LDAPUserProvider) modify(conn LDAPConnection, modifyRequest *ldap.ModifyRequest) (err error) { func (p *LDAPUserProvider) modify(client LDAPClient, modifyRequest *ldap.ModifyRequest) (err error) {
if err = conn.Modify(modifyRequest); err != nil { if err = client.Modify(modifyRequest); err != nil {
var ( var (
referral string referral string
ok bool ok bool
@ -402,28 +422,28 @@ func (p *LDAPUserProvider) modify(conn LDAPConnection, modifyRequest *ldap.Modif
p.log.Debugf("Attempting Modify on referred URL %s", referral) p.log.Debugf("Attempting Modify on referred URL %s", referral)
var ( var (
connReferral LDAPConnection clientRef LDAPClient
errReferral error errRef error
) )
if connReferral, errReferral = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errReferral != nil { if clientRef, errRef = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errRef != nil {
p.log.Errorf("Failed to connect during referred modify request (referred to %s): %v", referral, errReferral) return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
return err
} }
defer connReferral.Close() defer clientRef.Close()
if errReferral = connReferral.Modify(modifyRequest); errReferral != nil { if errRef = clientRef.Modify(modifyRequest); errRef != nil {
p.log.Errorf("Failed to perform modify operation during referred modify request (referred to %s): %v", referral, errReferral) return fmt.Errorf("error occurred performing modify on referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
}
} }
return err return nil
}
return nil
} }
func (p *LDAPUserProvider) pwdModify(conn LDAPConnection, pwdModifyRequest *ldap.PasswordModifyRequest) (err error) { func (p *LDAPUserProvider) pwdModify(client LDAPClient, pwdModifyRequest *ldap.PasswordModifyRequest) (err error) {
if _, err = conn.PasswordModify(pwdModifyRequest); err != nil { if _, err = client.PasswordModify(pwdModifyRequest); err != nil {
var ( var (
referral string referral string
ok bool ok bool
@ -436,24 +456,24 @@ func (p *LDAPUserProvider) pwdModify(conn LDAPConnection, pwdModifyRequest *ldap
p.log.Debugf("Attempting PwdModify ExOp (1.3.6.1.4.1.4203.1.11.1) on referred URL %s", referral) p.log.Debugf("Attempting PwdModify ExOp (1.3.6.1.4.1.4203.1.11.1) on referred URL %s", referral)
var ( var (
connReferral LDAPConnection clientRef LDAPClient
errReferral error errRef error
) )
if connReferral, errReferral = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errReferral != nil { if clientRef, errRef = p.connectCustom(referral, p.config.User, p.config.Password, p.config.StartTLS, p.dialOpts...); errRef != nil {
p.log.Errorf("Failed to connect during referred password modify request (referred to %s): %v", referral, errReferral) return fmt.Errorf("error occurred connecting to referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
return err
} }
defer connReferral.Close() defer clientRef.Close()
if _, errReferral = connReferral.PasswordModify(pwdModifyRequest); errReferral != nil { if _, errRef = clientRef.PasswordModify(pwdModifyRequest); errRef != nil {
p.log.Errorf("Failed to perform modify operation during referred modify request (referred to %s): %v", referral, errReferral) return fmt.Errorf("error occurred performing password modify on referred LDAP server '%s': %+v. Original Error: %w", referral, errRef, err)
}
} }
return err return nil
}
return nil
} }
func (p *LDAPUserProvider) getReferral(err error) (referral string, ok bool) { func (p *LDAPUserProvider) getReferral(err error) (referral string, ok bool) {

View File

@ -10,52 +10,77 @@ import (
// StartupCheck implements the startup check provider interface. // StartupCheck implements the startup check provider interface.
func (p *LDAPUserProvider) StartupCheck() (err error) { func (p *LDAPUserProvider) StartupCheck() (err error) {
var ( var client LDAPClient
conn LDAPConnection
searchResult *ldap.SearchResult
)
if conn, err = p.connect(); err != nil { if client, err = p.connect(); err != nil {
return err return err
} }
defer conn.Close() defer client.Close()
searchRequest := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases, if p.features, err = p.getServerSupportedFeatures(client); err != nil {
1, 0, false, "(objectClass=*)", []string{ldapSupportedExtensionAttribute}, nil)
if searchResult, err = conn.Search(searchRequest); err != nil {
return err return err
} }
if len(searchResult.Entries) != 1 { if !p.features.Extensions.PwdModifyExOp && !p.disableResetPassword &&
return nil
}
// Iterate the attribute values to see what the server supports.
for _, attr := range searchResult.Entries[0].Attributes {
if attr.Name == ldapSupportedExtensionAttribute {
p.log.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
for _, oid := range attr.Values {
if oid == ldapOIDPasswdModifyExtension {
p.supportExtensionPasswdModify = true
break
}
}
}
}
if !p.supportExtensionPasswdModify && !p.disableResetPassword &&
p.config.Implementation != schema.LDAPImplementationActiveDirectory { p.config.Implementation != schema.LDAPImplementationActiveDirectory {
p.log.Warn("Your LDAP server implementation may not support a method for password hashing " + p.log.Warn("Your LDAP server implementation may not support a method for password hashing " +
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " + "known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
"attribute when users reset their password via Authelia.") "attribute when users reset their password via Authelia.")
} }
if p.features.Extensions.TLS && !p.config.StartTLS && !strings.HasPrefix(p.config.URL, "ldaps://") {
p.log.Error("Your LDAP Server supports TLS but you don't appear to be utilizing it. We strongly" +
"recommend enabling the StartTLS option or using the scheme 'ldaps://' to secure connections with your" +
"LDAP Server.")
}
if !p.features.Extensions.TLS && p.config.StartTLS {
p.log.Info("Your LDAP Server does not appear to support TLS but you enabled StartTLS which may result" +
"in an error.")
}
return nil return nil
} }
func (p *LDAPUserProvider) getServerSupportedFeatures(client LDAPClient) (features LDAPSupportedFeatures, err error) {
var (
searchRequest *ldap.SearchRequest
searchResult *ldap.SearchResult
)
searchRequest = ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases,
1, 0, false, "(objectClass=*)", []string{ldapSupportedExtensionAttribute, ldapSupportedControlAttribute}, nil)
if searchResult, err = client.Search(searchRequest); err != nil {
return features, err
}
if len(searchResult.Entries) != 1 {
p.log.Errorf("The LDAP Server did not respond appropriately to a RootDSE search. This may result in reduced functionality.")
return features, nil
}
var controlTypeOIDs, extensionOIDs []string
controlTypeOIDs, extensionOIDs, features = ldapGetFeatureSupportFromEntry(searchResult.Entries[0])
controlTypes, extensions := none, none
if len(controlTypeOIDs) != 0 {
controlTypes = strings.Join(controlTypeOIDs, ", ")
}
if len(extensionOIDs) != 0 {
extensions = strings.Join(extensionOIDs, ", ")
}
p.log.Debugf("LDAP Supported OIDs. Control Types: %s. Extensions: %s", controlTypes, extensions)
return features, nil
}
func (p *LDAPUserProvider) parseDynamicUsersConfiguration() { func (p *LDAPUserProvider) parseDynamicUsersConfiguration() {
p.config.UsersFilter = strings.ReplaceAll(p.config.UsersFilter, "{username_attribute}", p.config.UsernameAttribute) p.config.UsersFilter = strings.ReplaceAll(p.config.UsersFilter, "{username_attribute}", p.config.UsernameAttribute)
p.config.UsersFilter = strings.ReplaceAll(p.config.UsersFilter, "{mail_attribute}", p.config.MailAttribute) p.config.UsersFilter = strings.ReplaceAll(p.config.UsersFilter, "{mail_attribute}", p.config.MailAttribute)

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,10 @@ import (
) )
func ldapEntriesContainsEntry(needle *ldap.Entry, haystack []*ldap.Entry) bool { func ldapEntriesContainsEntry(needle *ldap.Entry, haystack []*ldap.Entry) bool {
if needle == nil || len(haystack) == 0 {
return false
}
for i := 0; i < len(haystack); i++ { for i := 0; i < len(haystack); i++ {
if haystack[i].DN == needle.DN { if haystack[i].DN == needle.DN {
return true return true
@ -18,6 +22,41 @@ func ldapEntriesContainsEntry(needle *ldap.Entry, haystack []*ldap.Entry) bool {
return false return false
} }
func ldapGetFeatureSupportFromEntry(entry *ldap.Entry) (controlTypeOIDs, extensionOIDs []string, features LDAPSupportedFeatures) {
if entry == nil {
return controlTypeOIDs, extensionOIDs, features
}
for _, attr := range entry.Attributes {
switch attr.Name {
case ldapSupportedControlAttribute:
controlTypeOIDs = attr.Values
for _, oid := range attr.Values {
switch oid {
case ldapOIDControlMsftServerPolicyHints:
features.ControlTypes.MsftPwdPolHints = true
case ldapOIDControlMsftServerPolicyHintsDeprecated:
features.ControlTypes.MsftPwdPolHintsDeprecated = true
}
}
case ldapSupportedExtensionAttribute:
extensionOIDs = attr.Values
for _, oid := range attr.Values {
switch oid {
case ldapOIDExtensionPwdModifyExOp:
features.Extensions.PwdModifyExOp = true
case ldapOIDExtensionTLS:
features.Extensions.TLS = true
}
}
}
}
return controlTypeOIDs, extensionOIDs, features
}
func ldapEscape(inputUsername string) string { func ldapEscape(inputUsername string) string {
inputUsername = ldap.EscapeFilter(inputUsername) inputUsername = ldap.EscapeFilter(inputUsername)
for _, c := range specialLDAPRunes { for _, c := range specialLDAPRunes {
@ -34,10 +73,18 @@ func ldapGetReferral(err error) (referral string, ok bool) {
switch e := err.(type) { switch e := err.(type) {
case *ldap.Error: case *ldap.Error:
if e.Packet == nil {
return "", false
}
if len(e.Packet.Children) < 2 { if len(e.Packet.Children) < 2 {
return "", false return "", false
} }
if e.Packet.Children[1].Tag != ber.TagObjectDescriptor {
return "", false
}
for i := 0; i < len(e.Packet.Children[1].Children); i++ { for i := 0; i < len(e.Packet.Children[1].Children); i++ {
if e.Packet.Children[1].Children[i].Tag != ber.TagBitString || len(e.Packet.Children[1].Children[i].Children) < 1 { if e.Packet.Children[1].Children[i].Tag != ber.TagBitString || len(e.Packet.Children[1].Children[i].Children) < 1 {
continue continue

View File

@ -0,0 +1,320 @@
package authentication
import (
"errors"
"testing"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/go-ldap/ldap/v3"
"github.com/stretchr/testify/assert"
)
func TestLDAPGetFeatureSupportFromNilEntry(t *testing.T) {
control, extension, feature := ldapGetFeatureSupportFromEntry(nil)
assert.Len(t, control, 0)
assert.Len(t, extension, 0)
assert.Equal(t, LDAPSupportedFeatures{}, feature)
}
func TestLDAPGetFeatureSupportFromEntry(t *testing.T) {
testCases := []struct {
description string
haveControlOIDs, haveExtensionOIDs []string
expected LDAPSupportedFeatures
}{
{
description: "ShouldReturnExtensionPwdModifyExOp",
haveControlOIDs: []string{},
haveExtensionOIDs: []string{ldapOIDExtensionPwdModifyExOp},
expected: LDAPSupportedFeatures{Extensions: LDAPSupportedExtensions{PwdModifyExOp: true}},
},
{
description: "ShouldReturnExtensionTLS",
haveControlOIDs: []string{},
haveExtensionOIDs: []string{ldapOIDExtensionTLS},
expected: LDAPSupportedFeatures{Extensions: LDAPSupportedExtensions{TLS: true}},
},
{
description: "ShouldReturnExtensionAll",
haveControlOIDs: []string{},
haveExtensionOIDs: []string{ldapOIDExtensionTLS, ldapOIDExtensionPwdModifyExOp},
expected: LDAPSupportedFeatures{Extensions: LDAPSupportedExtensions{TLS: true, PwdModifyExOp: true}},
},
{
description: "ShouldReturnControlMsftPPolHints",
haveControlOIDs: []string{ldapOIDControlMsftServerPolicyHints},
haveExtensionOIDs: []string{},
expected: LDAPSupportedFeatures{ControlTypes: LDAPSupportedControlTypes{MsftPwdPolHints: true}},
},
{
description: "ShouldReturnControlMsftPPolHintsDeprecated",
haveControlOIDs: []string{ldapOIDControlMsftServerPolicyHintsDeprecated},
haveExtensionOIDs: []string{},
expected: LDAPSupportedFeatures{ControlTypes: LDAPSupportedControlTypes{MsftPwdPolHintsDeprecated: true}},
},
{
description: "ShouldReturnControlAll",
haveControlOIDs: []string{ldapOIDControlMsftServerPolicyHints, ldapOIDControlMsftServerPolicyHintsDeprecated},
haveExtensionOIDs: []string{},
expected: LDAPSupportedFeatures{ControlTypes: LDAPSupportedControlTypes{MsftPwdPolHints: true, MsftPwdPolHintsDeprecated: true}},
},
{
description: "ShouldReturnExtensionAndControlAll",
haveControlOIDs: []string{ldapOIDControlMsftServerPolicyHints, ldapOIDControlMsftServerPolicyHintsDeprecated},
haveExtensionOIDs: []string{ldapOIDExtensionTLS, ldapOIDExtensionPwdModifyExOp},
expected: LDAPSupportedFeatures{
ControlTypes: LDAPSupportedControlTypes{MsftPwdPolHints: true, MsftPwdPolHintsDeprecated: true},
Extensions: LDAPSupportedExtensions{TLS: true, PwdModifyExOp: true},
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
entry := &ldap.Entry{
DN: "",
Attributes: []*ldap.EntryAttribute{
{Name: ldapSupportedExtensionAttribute, Values: tc.haveExtensionOIDs},
{Name: ldapSupportedControlAttribute, Values: tc.haveControlOIDs},
},
}
actualControlOIDs, actualExtensionOIDs, actual := ldapGetFeatureSupportFromEntry(entry)
assert.Equal(t, tc.haveExtensionOIDs, actualExtensionOIDs)
assert.Equal(t, tc.haveControlOIDs, actualControlOIDs)
assert.Equal(t, tc.expected, actual)
})
}
}
func TestLDAPEntriesContainsEntry(t *testing.T) {
testCases := []struct {
description string
have []*ldap.Entry
lookingFor *ldap.Entry
expected bool
}{
{
description: "ShouldNotMatchNil",
have: []*ldap.Entry{
{DN: "test"},
},
lookingFor: nil,
expected: false,
},
{
description: "ShouldMatch",
have: []*ldap.Entry{
{DN: "test"},
},
lookingFor: &ldap.Entry{DN: "test"},
expected: true,
},
{
description: "ShouldMatchWhenMultiple",
have: []*ldap.Entry{
{DN: "False"},
{DN: "test"},
},
lookingFor: &ldap.Entry{DN: "test"},
expected: true,
},
{
description: "ShouldNotMatchDifferent",
have: []*ldap.Entry{
{DN: "False"},
{DN: "test"},
},
lookingFor: &ldap.Entry{DN: "not a result"},
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
assert.Equal(t, tc.expected, ldapEntriesContainsEntry(tc.lookingFor, tc.have))
})
}
}
func TestLDAPGetReferral(t *testing.T) {
testCases := []struct {
description string
have error
expectedReferral string
expectedOK bool
}{
{
description: "ShouldGetValidPacket",
have: &ldap.Error{ResultCode: ldap.LDAPResultReferral, Packet: &testBERPacketReferral},
expectedReferral: "ldap://192.168.0.1",
expectedOK: true,
},
{
description: "ShouldNotGetNilPacket",
have: &ldap.Error{ResultCode: ldap.LDAPResultReferral, Packet: nil},
expectedReferral: "",
expectedOK: false,
},
{
description: "ShouldNotGetInvalidPacketWithNoObjectDescriptor",
have: &ldap.Error{ResultCode: ldap.LDAPResultReferral, Packet: &testBERPacketReferralInvalidObjectDescriptor},
expectedReferral: "",
expectedOK: false,
},
{
description: "ShouldNotGetInvalidPacketWithBadErrorCode",
have: &ldap.Error{ResultCode: ldap.LDAPResultBusy, Packet: &testBERPacketReferral},
expectedReferral: "",
expectedOK: false,
},
{
description: "ShouldNotGetInvalidPacketWithoutBitString",
have: &ldap.Error{ResultCode: ldap.LDAPResultReferral, Packet: &testBERPacketReferralWithoutBitString},
expectedReferral: "",
expectedOK: false,
},
{
description: "ShouldNotGetInvalidPacketWithInvalidBitString",
have: &ldap.Error{ResultCode: ldap.LDAPResultReferral, Packet: &testBERPacketReferralWithInvalidBitString},
expectedReferral: "",
expectedOK: false,
},
{
description: "ShouldNotGetInvalidPacketWithoutEnoughChildren",
have: &ldap.Error{ResultCode: ldap.LDAPResultReferral, Packet: &testBERPacketReferralWithoutEnoughChildren},
expectedReferral: "",
expectedOK: false,
},
{
description: "ShouldNotGetInvalidErrType",
have: errors.New("not an err"),
expectedReferral: "",
expectedOK: false,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
referral, ok := ldapGetReferral(tc.have)
assert.Equal(t, tc.expectedOK, ok)
assert.Equal(t, tc.expectedReferral, referral)
})
}
}
var testBERPacketReferral = ber.Packet{
Children: []*ber.Packet{
{},
{
Identifier: ber.Identifier{
Tag: ber.TagObjectDescriptor,
},
Children: []*ber.Packet{
{
Identifier: ber.Identifier{
Tag: ber.TagBitString,
},
Children: []*ber.Packet{
{
Value: "ldap://192.168.0.1",
},
},
},
},
},
},
}
var testBERPacketReferralInvalidObjectDescriptor = ber.Packet{
Children: []*ber.Packet{
{},
{
Identifier: ber.Identifier{
Tag: ber.TagEOC,
},
Children: []*ber.Packet{
{
Identifier: ber.Identifier{
Tag: ber.TagBitString,
},
Children: []*ber.Packet{
{
Value: "ldap://192.168.0.1",
},
},
},
},
},
},
}
var testBERPacketReferralWithoutBitString = ber.Packet{
Children: []*ber.Packet{
{},
{
Identifier: ber.Identifier{
Tag: ber.TagObjectDescriptor,
},
Children: []*ber.Packet{
{
Identifier: ber.Identifier{
Tag: ber.TagSequence,
},
Children: []*ber.Packet{
{
Value: "ldap://192.168.0.1",
},
},
},
},
},
},
}
var testBERPacketReferralWithInvalidBitString = ber.Packet{
Children: []*ber.Packet{
{},
{
Identifier: ber.Identifier{
Tag: ber.TagObjectDescriptor,
},
Children: []*ber.Packet{
{
Identifier: ber.Identifier{
Tag: ber.TagBitString,
},
Children: []*ber.Packet{
{
Value: 55,
},
},
},
},
},
},
}
var testBERPacketReferralWithoutEnoughChildren = ber.Packet{
Children: []*ber.Packet{
{
Identifier: ber.Identifier{
Tag: ber.TagEOC,
},
Children: []*ber.Packet{
{
Identifier: ber.Identifier{
Tag: ber.TagBitString,
},
Children: []*ber.Packet{
{
Value: "ldap://192.168.0.1",
},
},
},
},
},
},
}

View File

@ -7,21 +7,24 @@ import (
"golang.org/x/text/encoding/unicode" "golang.org/x/text/encoding/unicode"
) )
// LDAPConnectionFactory an interface of factory of ldap connections. // LDAPClientFactory an interface of factory of LDAP clients.
type LDAPConnectionFactory interface { type LDAPClientFactory interface {
DialURL(addr string, opts ...ldap.DialOpt) (LDAPConnection, error) DialURL(addr string, opts ...ldap.DialOpt) (client LDAPClient, err error)
} }
// LDAPConnection interface representing a connection to the ldap. // LDAPClient is a cut down version of the ldap.Client interface with just the methods we use.
type LDAPConnection interface { //
Bind(username, password string) (err error) // Methods added to this interface that have a direct correlation with one from ldap.Client should have the same signature.
type LDAPClient interface {
Close() Close()
StartTLS(config *tls.Config) (err error) StartTLS(config *tls.Config) (err error)
Search(searchRequest *ldap.SearchRequest) (searchResult *ldap.SearchResult, err error) Bind(username, password string) (err error)
Modify(modifyRequest *ldap.ModifyRequest) (err error) Modify(modifyRequest *ldap.ModifyRequest) (err error)
PasswordModify(pwdModifyRequest *ldap.PasswordModifyRequest) (result *ldap.PasswordModifyResult, err error) PasswordModify(pwdModifyRequest *ldap.PasswordModifyRequest) (pwdModifyResult *ldap.PasswordModifyResult, err error)
Search(searchRequest *ldap.SearchRequest) (searchResult *ldap.SearchResult, err error)
} }
// UserDetails represent the details retrieved for a given user. // UserDetails represent the details retrieved for a given user.
@ -39,4 +42,22 @@ type ldapUserProfile struct {
Username string Username string
} }
// LDAPSupportedFeatures represents features which a server may support which are implemented in code.
type LDAPSupportedFeatures struct {
Extensions LDAPSupportedExtensions
ControlTypes LDAPSupportedControlTypes
}
// LDAPSupportedExtensions represents extensions which a server may support which are implemented in code.
type LDAPSupportedExtensions struct {
TLS bool
PwdModifyExOp bool
}
// LDAPSupportedControlTypes represents control types which a server may support which are implemented in code.
type LDAPSupportedControlTypes struct {
MsftPwdPolHints bool
MsftPwdPolHintsDeprecated bool
}
var utf16LittleEndian = unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) var utf16LittleEndian = unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)