diff --git a/internal/authentication/ldap_user_provider.go b/internal/authentication/ldap_user_provider.go index 4e44162d7..f1dc355c0 100644 --- a/internal/authentication/ldap_user_provider.go +++ b/internal/authentication/ldap_user_provider.go @@ -89,6 +89,17 @@ func (p *LDAPUserProvider) CheckUserPassword(username string, password string) ( return true, nil } +// OWASP recommends to escape some special characters +// https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/LDAP_Injection_Prevention_Cheat_Sheet.md +const SpecialLDAPRunes = "\\,#+<>;\"=" + +func (p *LDAPUserProvider) ldapEscape(input string) string { + for _, c := range SpecialLDAPRunes { + input = strings.ReplaceAll(input, string(c), fmt.Sprintf("\\%c", c)) + } + return input +} + func (p *LDAPUserProvider) getUserAttribute(conn LDAPConnection, username string, attribute string) ([]string, error) { client, err := p.connect(p.configuration.User, p.configuration.Password) if err != nil { @@ -96,6 +107,7 @@ func (p *LDAPUserProvider) getUserAttribute(conn LDAPConnection, username string } defer client.Close() + username = p.ldapEscape(username) userFilter := strings.Replace(p.configuration.UsersFilter, "{0}", username, -1) baseDN := p.configuration.BaseDN if p.configuration.AdditionalUsersDN != "" { diff --git a/internal/authentication/ldap_user_provider_test.go b/internal/authentication/ldap_user_provider_test.go index 00b0f7c61..fc77b7131 100644 --- a/internal/authentication/ldap_user_provider_test.go +++ b/internal/authentication/ldap_user_provider_test.go @@ -5,7 +5,9 @@ import ( "github.com/authelia/authelia/internal/configuration/schema" gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/ldap.v3" ) func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) { @@ -55,3 +57,80 @@ func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) { require.NoError(t, err) } + +func TestEscapeSpecialChars(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockFactory := NewMockLDAPConnectionFactory(ctrl) + ldap := NewLDAPUserProviderWithFactory(schema.LDAPAuthenticationBackendConfiguration{ + URL: "ldaps://127.0.0.1:389", + }, mockFactory) + + // No escape + assert.Equal(t, "xyz", ldap.ldapEscape("xyz")) + + // Escape + assert.Equal(t, "test\\,abc", ldap.ldapEscape("test,abc")) + assert.Equal(t, "test\\\\abc", ldap.ldapEscape("test\\abc")) + assert.Equal(t, "test\\#abc", ldap.ldapEscape("test#abc")) + assert.Equal(t, "test\\+abc", ldap.ldapEscape("test+abc")) + assert.Equal(t, "test\\abc", ldap.ldapEscape("test>abc")) + assert.Equal(t, "test\\;abc", ldap.ldapEscape("test;abc")) + assert.Equal(t, "test\\\"abc", ldap.ldapEscape("test\"abc")) + assert.Equal(t, "test\\=abc", ldap.ldapEscape("test=abc")) + +} + +type SearchRequestMatcher struct { + expected string +} + +func NewSearchRequestMatcher(expected string) *SearchRequestMatcher { + return &SearchRequestMatcher{expected} +} + +func (srm *SearchRequestMatcher) Matches(x interface{}) bool { + sr := x.(*ldap.SearchRequest) + return sr.Filter == srm.expected +} + +func (srm *SearchRequestMatcher) String() string { + return "" +} + +func TestShouldEscapeUserInput(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockFactory := NewMockLDAPConnectionFactory(ctrl) + mockConn := NewMockLDAPConnection(ctrl) + + ldapClient := NewLDAPUserProviderWithFactory(schema.LDAPAuthenticationBackendConfiguration{ + URL: "ldap://127.0.0.1:389", + User: "cn=admin,dc=example,dc=com", + Password: "password", + UsersFilter: "uid={0}", + AdditionalUsersDN: "ou=users", + BaseDN: "dc=example,dc=com", + }, mockFactory) + + mockFactory.EXPECT(). + Dial(gomock.Eq("tcp"), gomock.Eq("127.0.0.1:389")). + Return(mockConn, nil) + + mockConn.EXPECT(). + Bind(gomock.Eq("cn=admin,dc=example,dc=com"), gomock.Eq("password")). + Return(nil) + + mockConn.EXPECT(). + Close() + + mockConn.EXPECT(). + // Here we ensure that the input has been correctly escaped. + Search(NewSearchRequestMatcher("uid=john\\=abc")). + Return(&ldap.SearchResult{}, nil) + + ldapClient.getUserAttribute(mockConn, "john=abc", "dn") +}