diff --git a/internal/commands/util.go b/internal/commands/util.go index 4950c271b..bab9c70c9 100644 --- a/internal/commands/util.go +++ b/internal/commands/util.go @@ -215,13 +215,14 @@ const ( func loadXEnvCLIConfigValues(cmd *cobra.Command) (configs []string, filters []configuration.FileFilter, err error) { var ( filterNames []string + result XEnvCLIResult ) - if configs, _, err = loadXEnvCLIStringSliceValue(cmd, cmdFlagEnvNameConfig, cmdFlagNameConfig); err != nil { + if configs, result, err = loadXEnvCLIStringSliceValue(cmd, cmdFlagEnvNameConfig, cmdFlagNameConfig); err != nil { return nil, nil, err } - if configs, err = loadXNormalizedPaths(configs); err != nil { + if configs, err = loadXNormalizedPaths(configs, result); err != nil { return nil, nil, err } @@ -236,7 +237,7 @@ func loadXEnvCLIConfigValues(cmd *cobra.Command) (configs []string, filters []co return } -func loadXNormalizedPaths(paths []string) ([]string, error) { +func loadXNormalizedPaths(paths []string, result XEnvCLIResult) ([]string, error) { var ( configs, files, dirs []string err error @@ -258,10 +259,15 @@ func loadXNormalizedPaths(paths []string) ([]string, error) { files = append(files, path) default: if os.IsNotExist(err) { - configs = append(configs, path) - files = append(files, path) + switch result { + case XEnvCLIResultCLIImplicit: + continue + default: + configs = append(configs, path) + files = append(files, path) - continue + continue + } } return nil, fmt.Errorf("error occurred stating file at path '%s': %w", path, err) diff --git a/internal/commands/util_test.go b/internal/commands/util_test.go index 1aa6aaf83..f2306558b 100644 --- a/internal/commands/util_test.go +++ b/internal/commands/util_test.go @@ -96,6 +96,7 @@ func TestLoadXNormalizedPaths(t *testing.T) { ayml := filepath.Join(configdir, "a.yml") byml := filepath.Join(configdir, "b.yml") cyml := filepath.Join(otherdir, "c.yml") + dyml := filepath.Join(otherdir, "d.yml") file, err = os.Create(ayml) @@ -142,30 +143,44 @@ func TestLoadXNormalizedPaths(t *testing.T) { testCases := []struct { name string + haveX XEnvCLIResult have, expected []string expectedErr string }{ {"ShouldAllowFiles", + XEnvCLIResultCLIImplicit, []string{ayml}, []string{ayml}, - []string{ayml}, "", + "", + }, + {"ShouldSkipFilesNotExistImplicit", + XEnvCLIResultCLIImplicit, []string{dyml}, + []string(nil), + "", + }, + {"ShouldNotErrFilesNotExistExplicit", + XEnvCLIResultCLIExplicit, []string{dyml}, + []string{dyml}, + "", }, {"ShouldAllowDirectories", + XEnvCLIResultCLIImplicit, []string{configdir}, []string{configdir}, - []string{configdir}, "", + "", }, {"ShouldAllowFilesDirectories", + XEnvCLIResultCLIImplicit, []string{ayml, otherdir}, []string{ayml, otherdir}, - []string{ayml, otherdir}, "", + "", }, {"ShouldRaiseErrOnOverlappingFilesDirectories", - []string{ayml, configdir}, + XEnvCLIResultCLIImplicit, []string{ayml, configdir}, nil, fmt.Sprintf("failed to load config directory '%s': the config file '%s' is in that directory which is not supported", configdir, ayml), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - actual, actualErr := loadXNormalizedPaths(tc.have) + actual, actualErr := loadXNormalizedPaths(tc.have, tc.haveX) assert.Equal(t, tc.expected, actual) diff --git a/internal/handlers/handler_oidc_authorization.go b/internal/handlers/handler_oidc_authorization.go index 29147a0de..5cb193920 100644 --- a/internal/handlers/handler_oidc_authorization.go +++ b/internal/handlers/handler_oidc_authorization.go @@ -63,6 +63,28 @@ func OpenIDConnectAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWr return } + if !oidc.IsPushedAuthorizedRequest(requester, ctx.Providers.OpenIDConnect.GetPushedAuthorizeRequestURIPrefix(ctx)) { + if err = client.ValidatePKCEPolicy(requester); err != nil { + rfc := fosite.ErrorToRFC6749Error(err) + + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), client.GetID(), rfc.WithExposeDebug(true).GetDescription()) + + ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err) + + return + } + + if err = client.ValidateResponseModePolicy(requester); err != nil { + rfc := fosite.ErrorToRFC6749Error(err) + + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' failed to validate the Response Mode: %s", requester.GetID(), client.GetID(), rfc.WithExposeDebug(true).GetDescription()) + + ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err) + + return + } + } + if err = client.ValidatePKCEPolicy(requester); err != nil { rfc := fosite.ErrorToRFC6749Error(err) @@ -175,9 +197,19 @@ func OpenIDConnectPushedAuthorizationRequest(ctx *middlewares.AutheliaCtx, rw ht if err = client.ValidatePKCEPolicy(requester); err != nil { rfc := fosite.ErrorToRFC6749Error(err) - ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), clientID, rfc.WithExposeDebug(true).GetDescription()) + ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' failed to validate the PKCE policy: %s", requester.GetID(), client.GetID(), rfc.WithExposeDebug(true).GetDescription()) - ctx.Providers.OpenIDConnect.WritePushedAuthorizeError(ctx, rw, requester, err) + ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err) + + return + } + + if err = client.ValidateResponseModePolicy(requester); err != nil { + rfc := fosite.ErrorToRFC6749Error(err) + + ctx.Logger.Errorf("Pushed Authorization Request with id '%s' on client with id '%s' failed to validate the Response Mode: %s", requester.GetID(), client.GetID(), rfc.WithExposeDebug(true).GetDescription()) + + ctx.Providers.OpenIDConnect.WriteAuthorizeError(ctx, rw, requester, err) return } diff --git a/internal/model/totp_configuration.go b/internal/model/totp_configuration.go index a262ddb7b..0d325ce0a 100644 --- a/internal/model/totp_configuration.go +++ b/internal/model/totp_configuration.go @@ -25,9 +25,12 @@ type TOTPConfiguration struct { Secret []byte `db:"secret" json:"-"` } +// LastUsed provides LastUsedAt as a *time.Time instead of sql.NullTime. func (c *TOTPConfiguration) LastUsed() *time.Time { if c.LastUsedAt.Valid { - return &c.LastUsedAt.Time + value := time.Unix(c.LastUsedAt.Time.Unix(), int64(c.LastUsedAt.Time.Nanosecond())) + + return &value } return nil @@ -73,9 +76,9 @@ func (c *TOTPConfiguration) Image(width, height int) (img image.Image, err error return key.Image(width, height) } -// MarshalYAML marshals this model into YAML. -func (c *TOTPConfiguration) MarshalYAML() (any, error) { - o := TOTPConfigurationData{ +// ToData converts this TOTPConfiguration into the data format for exporting etc. +func (c *TOTPConfiguration) ToData() TOTPConfigurationData { + return TOTPConfigurationData{ CreatedAt: c.CreatedAt, LastUsedAt: c.LastUsed(), Username: c.Username, @@ -85,8 +88,11 @@ func (c *TOTPConfiguration) MarshalYAML() (any, error) { Period: c.Period, Secret: base64.StdEncoding.EncodeToString(c.Secret), } +} - return yaml.Marshal(o) +// MarshalYAML marshals this model into YAML. +func (c *TOTPConfiguration) MarshalYAML() (any, error) { + return c.ToData(), nil } // UnmarshalYAML unmarshalls YAML into this model. @@ -127,7 +133,30 @@ type TOTPConfigurationData struct { Secret string `yaml:"secret"` } +// TOTPConfigurationDataExport represents a TOTPConfiguration export file. +type TOTPConfigurationDataExport struct { + TOTPConfigurations []TOTPConfigurationData `yaml:"totp_configurations"` +} + // TOTPConfigurationExport represents a TOTPConfiguration export file. type TOTPConfigurationExport struct { TOTPConfigurations []TOTPConfiguration `yaml:"totp_configurations"` } + +// ToData converts this TOTPConfigurationExport into a TOTPConfigurationDataExport. +func (export TOTPConfigurationExport) ToData() TOTPConfigurationDataExport { + data := TOTPConfigurationDataExport{ + TOTPConfigurations: make([]TOTPConfigurationData, len(export.TOTPConfigurations)), + } + + for i, config := range export.TOTPConfigurations { + data.TOTPConfigurations[i] = config.ToData() + } + + return data +} + +// MarshalYAML marshals this model into YAML. +func (export TOTPConfigurationExport) MarshalYAML() (any, error) { + return export.ToData(), nil +} diff --git a/internal/model/totp_configuration_test.go b/internal/model/totp_configuration_test.go index 81d82aed3..07c38a1f7 100644 --- a/internal/model/totp_configuration_test.go +++ b/internal/model/totp_configuration_test.go @@ -1,11 +1,14 @@ package model import ( + "database/sql" "encoding/json" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) /* @@ -75,3 +78,62 @@ func TestShouldReturnImage(t *testing.T) { assert.Equal(t, 41, img.Bounds().Dx()) assert.Equal(t, 41, img.Bounds().Dy()) } + +func TestTOTPConfigurationImportExport(t *testing.T) { + have := TOTPConfigurationExport{ + TOTPConfigurations: []TOTPConfiguration{ + { + ID: 0, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Valid: false}, + Username: "john", + Issuer: "example", + Algorithm: "SHA1", + Digits: 6, + Period: 30, + Secret: MustRead(80), + }, + { + ID: 1, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Time: time.Now(), Valid: true}, + Username: "abc", + Issuer: "example2", + Algorithm: "SHA512", + Digits: 8, + Period: 90, + Secret: MustRead(120), + }, + }, + } + + out, err := yaml.Marshal(&have) + require.NoError(t, err) + + imported := TOTPConfigurationExport{} + + require.NoError(t, yaml.Unmarshal(out, &imported)) + + require.Equal(t, len(have.TOTPConfigurations), len(imported.TOTPConfigurations)) + + for i, actual := range imported.TOTPConfigurations { + t.Run(actual.Username, func(t *testing.T) { + expected := have.TOTPConfigurations[i] + + if expected.ID != 0 { + assert.NotEqual(t, expected.ID, actual.ID) + } else { + assert.Equal(t, expected.ID, actual.ID) + } + + assert.Equal(t, expected.Username, actual.Username) + assert.Equal(t, expected.Issuer, actual.Issuer) + assert.Equal(t, expected.Algorithm, actual.Algorithm) + assert.Equal(t, expected.Digits, actual.Digits) + assert.Equal(t, expected.Period, actual.Period) + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + assert.WithinDuration(t, expected.LastUsedAt.Time, actual.LastUsedAt.Time, time.Second) + assert.Equal(t, expected.LastUsedAt.Valid, actual.LastUsedAt.Valid) + }) + } +} diff --git a/internal/model/types_test.go b/internal/model/types_test.go index 50d8bf38c..ada1e9bbb 100644 --- a/internal/model/types_test.go +++ b/internal/model/types_test.go @@ -1,21 +1,11 @@ package model import ( - "fmt" "testing" - "github.com/ory/fosite" "github.com/stretchr/testify/assert" ) -func Test(t *testing.T) { - args := fosite.Arguments{"abc", "123"} - - x := StringSlicePipeDelimited(args) - - fmt.Println(x) -} - func TestDatabaseModelTypeIP(t *testing.T) { ip := IP{} diff --git a/internal/model/webauthn.go b/internal/model/webauthn.go index 6866ea5a7..bd5b662a5 100644 --- a/internal/model/webauthn.go +++ b/internal/model/webauthn.go @@ -196,14 +196,18 @@ func (d *WebAuthnDevice) UpdateSignInInfo(config *webauthn.Config, now time.Time } } +// DataValueLastUsedAt provides LastUsedAt as a *time.Time instead of sql.NullTime. func (d *WebAuthnDevice) DataValueLastUsedAt() *time.Time { if d.LastUsedAt.Valid { - return &d.LastUsedAt.Time + value := time.Unix(d.LastUsedAt.Time.Unix(), int64(d.LastUsedAt.Time.Nanosecond())) + + return &value } return nil } +// DataValueAAGUID provides AAGUID as a *string instead of uuid.NullUUID. func (d *WebAuthnDevice) DataValueAAGUID() *string { if d.AAGUID.Valid { value := d.AAGUID.UUID.String() @@ -382,3 +386,26 @@ func (d *WebAuthnDeviceData) ToDevice() (device *WebAuthnDevice, err error) { type WebAuthnDeviceExport struct { WebAuthnDevices []WebAuthnDevice `yaml:"webauthn_devices"` } + +// WebAuthnDeviceDataExport represents a WebAuthnDevice export file. +type WebAuthnDeviceDataExport struct { + WebAuthnDevices []WebAuthnDeviceData `yaml:"webauthn_devices"` +} + +// ToData converts this WebAuthnDeviceExport into a WebAuthnDeviceDataExport. +func (export WebAuthnDeviceExport) ToData() WebAuthnDeviceDataExport { + data := WebAuthnDeviceDataExport{ + WebAuthnDevices: make([]WebAuthnDeviceData, len(export.WebAuthnDevices)), + } + + for i, device := range export.WebAuthnDevices { + data.WebAuthnDevices[i] = device.ToData() + } + + return data +} + +// MarshalYAML marshals this model into YAML. +func (export WebAuthnDeviceExport) MarshalYAML() (any, error) { + return export.ToData(), nil +} diff --git a/internal/model/webauthn_test.go b/internal/model/webauthn_test.go new file mode 100644 index 000000000..654c1a178 --- /dev/null +++ b/internal/model/webauthn_test.go @@ -0,0 +1,88 @@ +package model + +import ( + "crypto/rand" + "database/sql" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestWebAuthnDeviceImportExport(t *testing.T) { + have := WebAuthnDeviceExport{ + WebAuthnDevices: []WebAuthnDevice{ + { + ID: 0, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Time: time.Now(), Valid: true}, + RPID: "example", + Username: "john", + Description: "akey", + KID: NewBase64(MustRead(20)), + PublicKey: MustRead(128), + AttestationType: "fido-u2f", + Transport: "", + AAGUID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + SignCount: 20, + CloneWarning: false, + }, + { + ID: 0, + CreatedAt: time.Now(), + LastUsedAt: sql.NullTime{Valid: false}, + RPID: "example2", + Username: "john2", + Description: "bkey", + KID: NewBase64(MustRead(60)), + PublicKey: MustRead(64), + AttestationType: "packed", + Transport: "", + AAGUID: uuid.NullUUID{Valid: false}, + SignCount: 30, + CloneWarning: true, + }, + }, + } + + out, err := yaml.Marshal(&have) + require.NoError(t, err) + + imported := WebAuthnDeviceExport{} + + require.NoError(t, yaml.Unmarshal(out, &imported)) + require.Equal(t, len(have.WebAuthnDevices), len(imported.WebAuthnDevices)) + + for i, actual := range imported.WebAuthnDevices { + t.Run(actual.Description, func(t *testing.T) { + expected := have.WebAuthnDevices[i] + + assert.Equal(t, expected.KID, actual.KID) + assert.Equal(t, expected.PublicKey, actual.PublicKey) + assert.Equal(t, expected.SignCount, actual.SignCount) + assert.Equal(t, expected.AttestationType, actual.AttestationType) + assert.Equal(t, expected.RPID, actual.RPID) + assert.Equal(t, expected.AAGUID.Valid, actual.AAGUID.Valid) + assert.Equal(t, expected.AAGUID.UUID, actual.AAGUID.UUID) + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + assert.WithinDuration(t, expected.LastUsedAt.Time, actual.LastUsedAt.Time, time.Second) + assert.Equal(t, expected.LastUsedAt.Valid, actual.LastUsedAt.Valid) + assert.Equal(t, expected.CloneWarning, actual.CloneWarning) + assert.Equal(t, expected.Description, actual.Description) + assert.Equal(t, expected.Username, actual.Username) + }) + } +} + +func MustRead(n int) []byte { + data := make([]byte, n) + + if _, err := rand.Read(data); err != nil { + panic(err) + } + + return data +} diff --git a/internal/oidc/client.go b/internal/oidc/client.go index 8388db9e7..41d302d03 100644 --- a/internal/oidc/client.go +++ b/internal/oidc/client.go @@ -1,8 +1,6 @@ package oidc import ( - "strings" - "github.com/ory/fosite" "github.com/ory/x/errorsx" @@ -30,7 +28,7 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client) RedirectURIs: config.RedirectURIs, GrantTypes: config.GrantTypes, ResponseTypes: config.ResponseTypes, - ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeDefault}, + ResponseModes: []fosite.ResponseModeType{}, EnforcePAR: config.EnforcePAR, @@ -73,21 +71,44 @@ func (c *Client) ValidatePKCEPolicy(r fosite.Requester) (err error) { // ValidatePARPolicy is a helper function to validate additional policy constraints on a per-client basis. func (c *Client) ValidatePARPolicy(r fosite.Requester, prefix string) (err error) { - form := r.GetRequestForm() - if c.EnforcePAR { - if requestURI := form.Get(FormParameterRequestURI); !strings.HasPrefix(requestURI, prefix) { - if requestURI == "" { + if !IsPushedAuthorizedRequest(r, prefix) { + switch requestURI := r.GetRequestForm().Get(FormParameterRequestURI); requestURI { + case "": return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebug("The request_uri parameter was empty.")) + default: + return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebugf("The request_uri parameter '%s' is malformed.", requestURI)) } - - return errorsx.WithStack(ErrPAREnforcedClientMissingPAR.WithDebugf("The request_uri parameter '%s' is malformed.", requestURI)) } } return nil } +// ValidateResponseModePolicy is an additional check to the response mode parameter to ensure if it's omitted that the +// default response mode for the fosite.AuthorizeRequester is permitted. +func (c *Client) ValidateResponseModePolicy(r fosite.AuthorizeRequester) (err error) { + if r.GetResponseMode() != fosite.ResponseModeDefault { + return nil + } + + m := r.GetDefaultResponseMode() + + modes := c.GetResponseModes() + + if len(modes) == 0 { + return nil + } + + for _, mode := range modes { + if m == mode { + return nil + } + } + + return errorsx.WithStack(fosite.ErrUnsupportedResponseMode.WithHintf(`The request omitted the response_mode making the default response_mode "%s" based on the other authorization request parameters but registered OAuth 2.0 client doesn't support this response_mode`, m)) +} + // IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient. func (c *Client) IsAuthenticationLevelSufficient(level authentication.Level) bool { if level == authentication.NotAuthenticated { diff --git a/internal/oidc/client_test.go b/internal/oidc/client_test.go index 747f1e8c2..1546644c2 100644 --- a/internal/oidc/client_test.go +++ b/internal/oidc/client_test.go @@ -1,6 +1,7 @@ package oidc import ( + "fmt" "testing" "github.com/ory/fosite" @@ -19,8 +20,7 @@ func TestNewClient(t *testing.T) { assert.Equal(t, "", blankClient.ID) assert.Equal(t, "", blankClient.Description) assert.Equal(t, "", blankClient.Description) - require.Len(t, blankClient.ResponseModes, 1) - assert.Equal(t, fosite.ResponseModeDefault, blankClient.ResponseModes[0]) + assert.Len(t, blankClient.ResponseModes, 0) exampleConfig := schema.OpenIDConnectClientConfiguration{ ID: "myapp", @@ -36,11 +36,10 @@ func TestNewClient(t *testing.T) { exampleClient := NewClient(exampleConfig) assert.Equal(t, "myapp", exampleClient.ID) - require.Len(t, exampleClient.ResponseModes, 4) - assert.Equal(t, fosite.ResponseModeDefault, exampleClient.ResponseModes[0]) - assert.Equal(t, fosite.ResponseModeFormPost, exampleClient.ResponseModes[1]) - assert.Equal(t, fosite.ResponseModeQuery, exampleClient.ResponseModes[2]) - assert.Equal(t, fosite.ResponseModeFragment, exampleClient.ResponseModes[3]) + require.Len(t, exampleClient.ResponseModes, 3) + assert.Equal(t, fosite.ResponseModeFormPost, exampleClient.ResponseModes[0]) + assert.Equal(t, fosite.ResponseModeQuery, exampleClient.ResponseModes[1]) + assert.Equal(t, fosite.ResponseModeFragment, exampleClient.ResponseModes[2]) assert.Equal(t, authorization.TwoFactor, exampleClient.Policy) } @@ -226,6 +225,7 @@ func TestNewClientPKCE(t *testing.T) { expected string r *fosite.Request err string + desc string }{ { "ShouldNotEnforcePKCEAndNotErrorOnNonPKCERequest", @@ -235,6 +235,7 @@ func TestNewClientPKCE(t *testing.T) { "", &fosite.Request{}, "", + "", }, { "ShouldEnforcePKCEAndErrorOnNonPKCERequest", @@ -244,6 +245,7 @@ func TestNewClientPKCE(t *testing.T) { "", &fosite.Request{}, "invalid_request", + "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. Clients must include a code_challenge when performing the authorize code flow, but it is missing. The server is configured in a way that enforces PKCE for this client.", }, { "ShouldEnforcePKCEAndNotErrorOnPKCERequest", @@ -253,6 +255,7 @@ func TestNewClientPKCE(t *testing.T) { "", &fosite.Request{Form: map[string][]string{"code_challenge": {"abc"}}}, "", + "", }, {"ShouldEnforcePKCEFromChallengeMethodAndErrorOnNonPKCERequest", schema.OpenIDConnectClientConfiguration{PKCEChallengeMethod: "S256"}, @@ -261,6 +264,7 @@ func TestNewClientPKCE(t *testing.T) { "S256", &fosite.Request{}, "invalid_request", + "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. Clients must include a code_challenge when performing the authorize code flow, but it is missing. The server is configured in a way that enforces PKCE for this client.", }, {"ShouldEnforcePKCEFromChallengeMethodAndErrorOnInvalidChallengeMethod", schema.OpenIDConnectClientConfiguration{PKCEChallengeMethod: "S256"}, @@ -269,6 +273,7 @@ func TestNewClientPKCE(t *testing.T) { "S256", &fosite.Request{Form: map[string][]string{"code_challenge": {"abc"}}}, "invalid_request", + "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. Client must use code_challenge_method=S256, is not allowed. The server is configured in a way that enforces PKCE S256 as challenge method for this client.", }, {"ShouldEnforcePKCEFromChallengeMethodAndNotErrorOnValidRequest", schema.OpenIDConnectClientConfiguration{PKCEChallengeMethod: "S256"}, @@ -277,6 +282,7 @@ func TestNewClientPKCE(t *testing.T) { "S256", &fosite.Request{Form: map[string][]string{"code_challenge": {"abc"}, "code_challenge_method": {"S256"}}}, "", + "", }, } @@ -292,7 +298,136 @@ func TestNewClientPKCE(t *testing.T) { err := client.ValidatePKCEPolicy(tc.r) if tc.err != "" { + require.NotNil(t, err) assert.EqualError(t, err, tc.err) + assert.Equal(t, tc.desc, fosite.ErrorToRFC6749Error(err).WithExposeDebug(true).GetDescription()) + } else { + assert.NoError(t, err) + } + } + }) + } +} + +func TestNewClientPAR(t *testing.T) { + testCases := []struct { + name string + have schema.OpenIDConnectClientConfiguration + expected bool + r *fosite.Request + err string + desc string + }{ + { + "ShouldNotEnforcEPARAndNotErrorOnNonPARRequest", + schema.OpenIDConnectClientConfiguration{}, + false, + &fosite.Request{}, + "", + "", + }, + { + "ShouldEnforcePARAndErrorOnNonPARRequest", + schema.OpenIDConnectClientConfiguration{EnforcePAR: true}, + true, + &fosite.Request{}, + "invalid_request", + "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. Pushed Authorization Requests are enforced for this client but no such request was sent. The request_uri parameter was empty.", + }, + { + "ShouldEnforcePARAndErrorOnNonPARRequest", + schema.OpenIDConnectClientConfiguration{EnforcePAR: true}, + true, + &fosite.Request{Form: map[string][]string{FormParameterRequestURI: {"https://example.com"}}}, + "invalid_request", + "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed. Pushed Authorization Requests are enforced for this client but no such request was sent. The request_uri parameter 'https://example.com' is malformed."}, + { + "ShouldEnforcePARAndNotErrorOnPARRequest", + schema.OpenIDConnectClientConfiguration{EnforcePAR: true}, + true, + &fosite.Request{Form: map[string][]string{FormParameterRequestURI: {fmt.Sprintf("%sabc", urnPARPrefix)}}}, + "", + "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := NewClient(tc.have) + + assert.Equal(t, tc.expected, client.EnforcePAR) + + if tc.r != nil { + err := client.ValidatePARPolicy(tc.r, urnPARPrefix) + + if tc.err != "" { + require.NotNil(t, err) + assert.EqualError(t, err, tc.err) + assert.Equal(t, tc.desc, fosite.ErrorToRFC6749Error(err).WithExposeDebug(true).GetDescription()) + } else { + assert.NoError(t, err) + } + } + }) + } +} + +func TestNewClientResponseModes(t *testing.T) { + testCases := []struct { + name string + have schema.OpenIDConnectClientConfiguration + expected []fosite.ResponseModeType + r *fosite.AuthorizeRequest + err string + desc string + }{ + { + "ShouldEnforceResponseModePolicyAndAllowDefaultModeQuery", + schema.OpenIDConnectClientConfiguration{ResponseModes: []string{ResponseModeQuery}}, + []fosite.ResponseModeType{fosite.ResponseModeQuery}, + &fosite.AuthorizeRequest{DefaultResponseMode: fosite.ResponseModeQuery, ResponseMode: fosite.ResponseModeDefault, Request: fosite.Request{Form: map[string][]string{FormParameterResponseMode: nil}}}, + "", + "", + }, + { + "ShouldEnforceResponseModePolicyAndFailOnDefaultMode", + schema.OpenIDConnectClientConfiguration{ResponseModes: []string{ResponseModeFormPost}}, + []fosite.ResponseModeType{fosite.ResponseModeFormPost}, + &fosite.AuthorizeRequest{DefaultResponseMode: fosite.ResponseModeQuery, ResponseMode: fosite.ResponseModeDefault, Request: fosite.Request{Form: map[string][]string{FormParameterResponseMode: nil}}}, + "unsupported_response_mode", + "The authorization server does not support obtaining a response using this response mode. The request omitted the response_mode making the default response_mode 'query' based on the other authorization request parameters but registered OAuth 2.0 client doesn't support this response_mode", + }, + { + "ShouldNotEnforceConfiguredResponseMode", + schema.OpenIDConnectClientConfiguration{ResponseModes: []string{ResponseModeFormPost}}, + []fosite.ResponseModeType{fosite.ResponseModeFormPost}, + &fosite.AuthorizeRequest{DefaultResponseMode: fosite.ResponseModeQuery, ResponseMode: fosite.ResponseModeQuery, Request: fosite.Request{Form: map[string][]string{FormParameterResponseMode: {ResponseModeQuery}}}}, + "", + "", + }, + { + "ShouldNotEnforceUnconfiguredResponseMode", + schema.OpenIDConnectClientConfiguration{ResponseModes: []string{}}, + []fosite.ResponseModeType{}, + &fosite.AuthorizeRequest{DefaultResponseMode: fosite.ResponseModeQuery, ResponseMode: fosite.ResponseModeDefault, Request: fosite.Request{Form: map[string][]string{FormParameterResponseMode: {ResponseModeQuery}}}}, + "", + "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := NewClient(tc.have) + + assert.Equal(t, tc.expected, client.GetResponseModes()) + + if tc.r != nil { + err := client.ValidateResponseModePolicy(tc.r) + + if tc.err != "" { + require.NotNil(t, err) + assert.EqualError(t, err, tc.err) + assert.Equal(t, tc.desc, fosite.ErrorToRFC6749Error(err).WithExposeDebug(true).GetDescription()) } else { assert.NoError(t, err) } diff --git a/internal/oidc/const.go b/internal/oidc/const.go index f7484e958..54c6b2ff4 100644 --- a/internal/oidc/const.go +++ b/internal/oidc/const.go @@ -112,6 +112,7 @@ const ( const ( FormParameterRequestURI = "request_uri" + FormParameterResponseMode = "response_mode" FormParameterCodeChallenge = "code_challenge" FormParameterCodeChallengeMethod = "code_challenge_method" ) diff --git a/internal/oidc/util.go b/internal/oidc/util.go new file mode 100644 index 000000000..2e7c14d37 --- /dev/null +++ b/internal/oidc/util.go @@ -0,0 +1,12 @@ +package oidc + +import ( + "strings" + + "github.com/ory/fosite" +) + +// IsPushedAuthorizedRequest returns true if the requester has a PushedAuthorizationRequest redirect_uri value. +func IsPushedAuthorizedRequest(r fosite.Requester, prefix string) bool { + return strings.HasPrefix(r.GetRequestForm().Get(FormParameterRequestURI), prefix) +} diff --git a/internal/templates/src/oidc/AuthorizeResponseFormPost.html b/internal/templates/src/oidc/AuthorizeResponseFormPost.html index 93569b59b..32c6911a6 100644 --- a/internal/templates/src/oidc/AuthorizeResponseFormPost.html +++ b/internal/templates/src/oidc/AuthorizeResponseFormPost.html @@ -10,10 +10,10 @@