diff --git a/e2e/keycloak/docker-compose.yaml b/e2e/keycloak/docker-compose.yaml index 09ee671..d98edfb 100644 --- a/e2e/keycloak/docker-compose.yaml +++ b/e2e/keycloak/docker-compose.yaml @@ -73,7 +73,7 @@ services: test: /opt/keycloak/bin/kcadm.sh get realms/master --server http://localhost:8080 --realm master --user admin --password admin interval: 5s timeout: 2s - retries: 10 + retries: 30 start_period: 5s extra_hosts: # Required when running on Linux - "host.docker.internal:host-gateway" diff --git a/e2e/keycloak/keycloak_test.go b/e2e/keycloak/keycloak_test.go index 656190d..697ec1d 100644 --- a/e2e/keycloak/keycloak_test.go +++ b/e2e/keycloak/keycloak_test.go @@ -20,6 +20,7 @@ import ( "net" "net/http" "testing" + "time" "github.com/stretchr/testify/require" @@ -74,3 +75,57 @@ func TestOIDC(t *testing.T) { require.Equal(t, http.StatusOK, res.StatusCode) require.Contains(t, string(body), "Access allowed") } + +func TestOIDCRefreshTokens(t *testing.T) { + skipIfDockerHostNonResolvable(t) + + // Initialize the test OIDC client that will keep track of the state of the OIDC login process + client, err := common.NewOIDCTestClient( + common.WithCustomCA(testCAFile), + common.WithLoggingOptions(t.Log, true), + ) + require.NoError(t, err) + + // Send a request to the test server. It will be redirected to the IdP login page + res, err := client.Get(testURL) + require.NoError(t, err) + + // Parse the response body to get the URL where the login page would post the user-entered credentials + require.NoError(t, client.ParseLoginForm(res.Body, keyCloakLoginFormID)) + + // Submit the login form to the IdP. This will authenticate and redirect back to the application + res, err = client.Login(map[string]string{"username": username, "password": password, "credentialId": ""}) + require.NoError(t, err) + + // Verify that we get the expected response from the application + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, res.StatusCode) + require.Contains(t, string(body), "Access allowed") + + // Access tokens should expire in 10 seconds (tried with 5, but keycloak setup fails) + // Let's perform a request now and after 10 seconds to verify that the access token is refreshed + + t.Run("request with same tokens", func(t *testing.T) { + res, err = client.Get(testURL) + require.NoError(t, err) + + body, err = io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, res.StatusCode) + require.Contains(t, string(body), "Access allowed") + }) + + t.Log("waiting for access token to expire...") + time.Sleep(10 * time.Second) + + t.Run("request with expired tokens", func(t *testing.T) { + res, err = client.Get(testURL) + require.NoError(t, err) + + body, err = io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, res.StatusCode) + require.Contains(t, string(body), "Access allowed") + }) +} diff --git a/e2e/keycloak/setup-keycloak.sh b/e2e/keycloak/setup-keycloak.sh index fb8e245..7fc46af 100755 --- a/e2e/keycloak/setup-keycloak.sh +++ b/e2e/keycloak/setup-keycloak.sh @@ -24,6 +24,13 @@ REDIRECT_URL=https://host.docker.internal:8443/callback set -ex +/opt/keycloak/bin/kcadm.sh update realms/${REALM} \ + -s accessTokenLifespan=10 \ + --realm "${REALM}" \ + --server "${KEYCLOAK_SERVER}" \ + --user "${KEYCLOAK_ADMIN}" \ + --password "${KEYCLOAK_ADMIN_PASSWORD}" + /opt/keycloak/bin/kcadm.sh create users \ -s username="${USERNAME}" \ -s enabled=true \ diff --git a/e2e/suite.mk b/e2e/suite.mk index 6cd81b7..18c0dbb 100644 --- a/e2e/suite.mk +++ b/e2e/suite.mk @@ -18,7 +18,8 @@ ROOT := $(shell git rev-parse --show-toplevel) -include $(ROOT)/env.mk +include $(ROOT)/env.mk # Load common variables +-include $(ROOT)/.makerc # Pick up any local overrides. # Force run of the e2e tests by default E2E_TEST_OPTS ?= -count=1 diff --git a/internal/authz/oidc.go b/internal/authz/oidc.go index 57402d1..96b46cf 100644 --- a/internal/authz/oidc.go +++ b/internal/authz/oidc.go @@ -205,7 +205,20 @@ func (o *oidcHandler) Process(ctx context.Context, req *envoy.CheckRequest, resp // token_response. If successful, allow the request to proceed. If // unsuccessful, redirect for login. log.Debug("attempting token refresh") - // TODO (sergicastro): Handle token refresh + refreshedTokens := o.refreshToken(ctx, log, tokenResponse, tokenResponse.RefreshToken, sessionID) + if refreshedTokens == nil { + log.Info("Token refresh failed. Sending user to re-authenticate.") + o.redirectToIDP(ctx, log, resp, req.GetAttributes().GetRequest().GetHttp(), sessionID) + return nil + } + if err := store.SetTokenResponse(ctx, sessionID, refreshedTokens); err != nil { + log.Error("error saving refreshed tokens to session store", err) + setDenyResponse(resp, newSessionErrorResponse(), codes.Unauthenticated) + return nil + } + + log.Info("Token refresh successful. Allowing request to proceed.") + o.allowResponse(resp, refreshedTokens) return nil } @@ -247,13 +260,14 @@ func (o *oidcHandler) redirectToIDP(ctx context.Context, log telemetry.Logger, } // Generate the redirect URL - query := url.Values{} - query.Add("response_type", "code") - query.Add("client_id", o.config.GetClientId()) - query.Add("redirect_uri", o.config.GetCallbackUri()) - query.Add("scope", strings.Join(o.config.GetScopes(), " ")) - query.Add("state", state) - query.Add("nonce", nonce) + query := url.Values{ + "response_type": []string{"code"}, + "client_id": []string{o.config.GetClientId()}, + "redirect_uri": []string{o.config.GetCallbackUri()}, + "scope": []string{strings.Join(o.config.GetScopes(), " ")}, + "state": []string{state}, + "nonce": []string{nonce}, + } redirectURL := o.config.GetAuthorizationUri() + "?" + query.Encode() // Generate denied response with redirect headers @@ -273,6 +287,7 @@ func (o *oidcHandler) redirectToIDP(ctx context.Context, log telemetry.Logger, setDenyResponse(resp, deny, codes.Unauthenticated) } +// retrieveTokens retrieves the tokens from the Identity Provider and redirects the user back to the originally requested URL. func (o *oidcHandler) retrieveTokens(ctx context.Context, log telemetry.Logger, req *envoy.CheckRequest, resp *envoy.CheckResponse, sessionID string) { store := o.sessions.Get(o.config) @@ -327,65 +342,263 @@ func (o *oidcHandler) retrieveTokens(ctx context.Context, log telemetry.Logger, "redirect_uri": []string{o.config.GetCallbackUri()}, } - oidcReq, err := http.NewRequest("POST", o.config.GetTokenUri(), strings.NewReader(form.Encode())) - if err != nil { - log.Error("error creating tokens request to OIDC", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) + // build headers + headers := http.Header{ + inthttp.HeaderContentType: []string{inthttp.HeaderContentTypeFormURLEncoded}, + inthttp.HeaderAuthorization: []string{inthttp.BasicAuthHeader(o.config.GetClientId(), o.config.GetClientSecret())}, + } + + log.Info("performing request to retrieve new tokens") + bodyTokens, errCode := performIDPRequest(log, o.httpClient, o.config.GetTokenUri(), form, headers) + if errCode != codes.OK { + setDenyResponse(resp, newDenyResponse(), errCode) + return + } + + // validate IDP tokens response + if !isValidIDPNewTokensResponse(log, o.config, bodyTokens) { + setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) return } + // validate ID token + if ok, errCode := o.isValidIDToken(ctx, log, bodyTokens.IDToken, stateFromStore.Nonce, true); !ok { + setDenyResponse(resp, newDenyResponse(), errCode) + return + } + + if err := store.ClearAuthorizationState(ctx, sessionID); err != nil { + log.Error("error clearing authorization state", err) + setDenyResponse(resp, newSessionErrorResponse(), codes.Unauthenticated) + return + } + + // Knock 5 seconds off the expiry time to take into account the time it may + // have taken to retrieve the token. + expiresIn := time.Duration(bodyTokens.ExpiresIn)*time.Second - 5 + accessTokenExpiration := o.clock.Now().Add(expiresIn) + + log.Debug("saving tokens to session store") + if err := store.SetTokenResponse(ctx, sessionID, &oidc.TokenResponse{ + IDToken: bodyTokens.IDToken, + AccessToken: bodyTokens.AccessToken, + RefreshToken: bodyTokens.RefreshToken, + AccessTokenExpiresAt: accessTokenExpiration, + }); err != nil { + log.Error("error saving tokens to session store", err) + setDenyResponse(resp, newSessionErrorResponse(), codes.Unauthenticated) + return + } + log.Debug("tokens retrieved successfully") + + deny := newDenyResponse() + deny.Status = &typev3.HttpStatus{Code: typev3.StatusCode_Found} + deny.Headers = append(deny.Headers, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{Key: inthttp.HeaderLocation, Value: stateFromStore.RequestedURL}, + }) + setDenyResponse(resp, deny, codes.Unauthenticated) +} + +// refreshToken retrieves new tokens from the Identity Provider using the given refresh token. +func (o *oidcHandler) refreshToken(ctx context.Context, log telemetry.Logger, expiredTokens *oidc.TokenResponse, token, sessionID string) *oidc.TokenResponse { + store := o.sessions.Get(o.config) + + form := url.Values{ + "grant_type": []string{"refresh_token"}, + "refresh_token": []string{token}, + "client_id": []string{o.config.GetClientId()}, + "client_secret": []string{o.config.GetClientSecret()}, + // according to this link, omitting the `scope` param should return new + // tokens with the previously requested `scope` + // https://www.oauth.com/oauth2-servers/access-tokens/refreshing-access-tokens/ + } + // build headers - oidcReq.Header = http.Header{ - inthttp.HeaderContentType: []string{inthttp.HeaderContentTypeFormURLEncoded}, - inthttp.HeaderAuthorization: []string{inthttp.BasicAuthHeader(o.config.GetClientId(), o.config.GetClientSecret())}, + headers := http.Header{ + inthttp.HeaderContentType: []string{inthttp.HeaderContentTypeFormURLEncoded}, + } + + log.Info("performing request to refresh access token") + bodyTokens, errCode := performIDPRequest(log, o.httpClient, o.config.GetTokenUri(), form, headers) + + if errCode != codes.OK { + return nil + } + + // validate IDP tokens response + if !isValidIDPRefreshTokenResponse(log, bodyTokens) { + //setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) + return nil + } + + // merge the new tokens with the stored ones + newTokenResponse := &oidc.TokenResponse{} + + _, err := oidc.ParseToken(bodyTokens.IDToken) + if err != nil { + log.Error("error parsing new id token, using the old one", err) + newTokenResponse.IDToken = expiredTokens.IDToken + } else { + log.Debug("updating id token") + newTokenResponse.IDToken = bodyTokens.IDToken + } + + if bodyTokens.AccessToken != "" { + log.Debug("updating access token") + newTokenResponse.AccessToken = bodyTokens.AccessToken + } else { + newTokenResponse.AccessToken = expiredTokens.AccessToken + } + + if bodyTokens.RefreshToken != "" { + log.Debug("updating refresh token") + newTokenResponse.RefreshToken = bodyTokens.RefreshToken + } else { + newTokenResponse.RefreshToken = expiredTokens.RefreshToken + } + + if bodyTokens.ExpiresIn > 0 { + log.Debug("updating access token expiration") + // Knock 5 seconds off the expiry time to take into account the time it may + // have taken to retrieve the token. + expiresIn := time.Duration(bodyTokens.ExpiresIn)*time.Second - 5 + newTokenResponse.AccessTokenExpiresAt = o.clock.Now().Add(expiresIn) + } else { + newTokenResponse.AccessTokenExpiresAt = expiredTokens.AccessTokenExpiresAt + } + + stateFromStore, err := store.GetAuthorizationState(ctx, sessionID) + if err != nil { + log.Error("error retrieving authorization state from session store", err) + return nil + } + var expectedNonce string + if stateFromStore != nil { + expectedNonce = stateFromStore.Nonce + } + + // validate the id token + if ok, _ := o.isValidIDToken(context.Background(), log, newTokenResponse.IDToken, expectedNonce, false); !ok { + return nil + } + + return newTokenResponse +} + +// idpTokensResponse is the response from the Identity Provider when requesting tokens. +type idpTokensResponse struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + DeviceSecret string `json:"device_secret"` +} + +// performIDPRequest performs a request to the Identity Provider to retrieve tokens. +func performIDPRequest(log telemetry.Logger, client *http.Client, uri string, form url.Values, headers http.Header) (*idpTokensResponse, codes.Code) { + oidcReq, err := http.NewRequest("POST", uri, strings.NewReader(form.Encode())) + if err != nil { + log.Error("error creating tokens request to OIDC", err) + return nil, codes.Internal } + oidcReq.Header = headers - oidcResp, err := o.httpClient.Do(oidcReq) + oidcResp, err := client.Do(oidcReq) if err != nil { log.Error("error performing tokens request to OIDC", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) - return + return nil, codes.Internal } if oidcResp.StatusCode != http.StatusOK { log.Info("OIDC server returned non-200 status code", "status-code", oidcResp.StatusCode, "url", oidcReq.URL.String()) - setDenyResponse(resp, newDenyResponse(), codes.Unknown) - return + return nil, codes.Unknown } respBody, err := io.ReadAll(oidcResp.Body) _ = oidcResp.Body.Close() if err != nil { log.Error("error reading tokens response", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) - return + return nil, codes.Internal } - bodyTokens := &tokensResponse{} + bodyTokens := &idpTokensResponse{} err = json.Unmarshal(respBody, &bodyTokens) if err != nil { log.Error("error unmarshalling tokens response", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) - return + return nil, codes.Internal + } + + return bodyTokens, codes.OK +} + +// isValidIDPNewTokensResponse checks if the response from the Identity Provider is valid according to the OpenID Connect specification. +// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse +func isValidIDPNewTokensResponse(log telemetry.Logger, config *oidcv1.OIDCConfig, tokenResponse *idpTokensResponse) bool { + // token_type must be Bearer + if tokenResponse.TokenType != "Bearer" { + log.Info("token type is not Bearer in token response", "token-type", tokenResponse.TokenType) + return false + } + + // expires_in must be a positive value + if tokenResponse.ExpiresIn < 0 { + log.Info("expires_in is not a positive value in token response", "expires-in", tokenResponse.ExpiresIn) + return false } - idToken, err := oidc.ParseToken(bodyTokens.IDToken) + // If access_token forwarding is configured but there is not an access token + // in the token response then there is a problem + if config.GetAccessToken() != nil && tokenResponse.AccessToken == "" { + log.Info("access token forwarding is configured but no access token was returned") + return false + } + + return true +} + +// isValidIDPRefreshTokenResponse checks if the response from the Identity Provider is valid according to the OpenID Connect specification. +// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse +func isValidIDPRefreshTokenResponse(log telemetry.Logger, tokenResponse *idpTokensResponse) bool { + // token_type must be Bearer + if tokenResponse.TokenType != "Bearer" { + log.Info("token type is not Bearer in token response", "token-type", tokenResponse.TokenType) + return false + } + + // expires_in must be a positive value + if tokenResponse.ExpiresIn < 0 { + log.Info("expires_in is not a positive value in token response", "expires-in", tokenResponse.ExpiresIn) + return false + } + + return true +} + +// isValidIDToken checks if the id token is valid according to the OpenID Connect specification. +// It checks the nonce, audience, and verifies the signature with the fetched jwks. +// It returns a boolean indicating if the token is valid and a code indicating the reason if it is not. +// If the nonce is not required, it will only check the expectedNonce against the token's nonce if it is present, as OIDC spec defines. +func (o *oidcHandler) isValidIDToken(ctx context.Context, log telemetry.Logger, idTokenString, expectedNonce string, isNonceRequired bool) (bool, codes.Code) { + idToken, err := oidc.ParseToken(idTokenString) if err != nil { log.Error("error parsing id token", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) - return + return false, codes.Internal } oidcNonce, ok := idToken.Get("nonce") - if !ok { + if !ok && isNonceRequired { log.Info("id token does not have nonce claim") - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return - } - if oidcNonce.(string) != stateFromStore.Nonce { - log.Info("id token nonce does not match", "nonce-from-id-token", oidcNonce, "nonce-from-store", stateFromStore.Nonce) - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return + return false, codes.InvalidArgument + } + if ok { + tokenNonce := oidcNonce.(string) + // if nonce is not required, both token and expected nonce must be present to perform the check + if (isNonceRequired || tokenNonce != "" && expectedNonce != "") && tokenNonce != expectedNonce { + log.Info("id token nonce does not match", "nonce-from-id-token", oidcNonce, "nonce-from-store", expectedNonce) + return false, codes.InvalidArgument + } } var audMatches bool @@ -397,86 +610,21 @@ func (o *oidcHandler) retrieveTokens(ctx context.Context, log telemetry.Logger, } if !audMatches { log.Info("id token audience does not match", "aud-from-id-token", idToken.Audience(), "aud-from-config", o.config.GetClientId()) - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return + return false, codes.InvalidArgument } jwtSet, err := o.jwks.Get(ctx, o.config) if err != nil { log.Error("error fetching jwks", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) - return + return false, codes.Internal } - if _, err := jws.VerifySet([]byte(bodyTokens.IDToken), jwtSet); err != nil { + if _, err := jws.VerifySet([]byte(idTokenString), jwtSet); err != nil { log.Error("error verifying id token with fetched jwks", err) - setDenyResponse(resp, newDenyResponse(), codes.Internal) - return - } - - // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse - // token_type must be Bearer - if bodyTokens.TokenType != "Bearer" { - log.Info("token type is not Bearer in token response", "token-type", bodyTokens.TokenType) - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return + return false, codes.Internal } - // expires_in must be a positive value - if bodyTokens.ExpiresIn < 0 { - log.Info("expires_in is not a positive value in token response", "expires-in", bodyTokens.ExpiresIn) - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return - } - - // Knock 5 seconds off the expiry time to take into account the time it may - // have taken to retrieve the token. - expiresIn := time.Duration(bodyTokens.ExpiresIn)*time.Second - 5 - accessTokenExpiration := o.clock.Now().Add(expiresIn) - - // If access_token forwarding is configured but there is not an access token - // in the token response then there is a problem - if o.config.GetAccessToken() != nil && bodyTokens.AccessToken == "" { - log.Info("access token forwarding is configured but no access token was returned") - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return - } - - if err := store.ClearAuthorizationState(ctx, sessionID); err != nil { - log.Error("error clearing authorization state", err) - setDenyResponse(resp, newSessionErrorResponse(), codes.Unauthenticated) - return - } - - log.Debug("saving tokens to session store") - if err := store.SetTokenResponse(ctx, sessionID, &oidc.TokenResponse{ - IDToken: bodyTokens.IDToken, - AccessToken: bodyTokens.AccessToken, - RefreshToken: bodyTokens.RefreshToken, - AccessTokenExpiresAt: accessTokenExpiration, - }); err != nil { - log.Error("error saving tokens to session store", err) - setDenyResponse(resp, newSessionErrorResponse(), codes.Unauthenticated) - return - } - log.Debug("tokens retrieved successfully") - - deny := newDenyResponse() - deny.Status = &typev3.HttpStatus{Code: typev3.StatusCode_Found} - deny.Headers = append(deny.Headers, &corev3.HeaderValueOption{ - Header: &corev3.HeaderValue{Key: inthttp.HeaderLocation, Value: stateFromStore.RequestedURL}, - }) - setDenyResponse(resp, deny, codes.Unauthenticated) -} - -type tokensResponse struct { - IDToken string `json:"id_token"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - TokenType string `json:"token_type"` - Scope string `json:"scope"` - DeviceSecret string `json:"device_secret"` + return true, codes.OK } // newDenyResponse creates a new DeniedHttpResponse with the standard headers. diff --git a/internal/authz/oidc_test.go b/internal/authz/oidc_test.go index 0718295..fdca6ba 100644 --- a/internal/authz/oidc_test.go +++ b/internal/authz/oidc_test.go @@ -141,17 +141,20 @@ var ( "token_endpoint": "http://idp-test-server/token", "jwks_uri": "http://idp-test-server/jwks" }` + + wantRedirectParams = url.Values{ + "response_type": {"code"}, + "client_id": {"test-client-id"}, + "redirect_uri": {"https://localhost:443/callback"}, + "scope": {"openid email"}, + "state": {newState}, + "nonce": {newNonce}, + } + + wantRedirectBaseURI = "http://idp-test-server/auth" ) func TestOIDCProcess(t *testing.T) { - wantRedirectParams := url.Values{} - wantRedirectParams.Add("response_type", "code") - wantRedirectParams.Add("client_id", "test-client-id") - wantRedirectParams.Add("redirect_uri", "https://localhost:443/callback") - wantRedirectParams.Add("scope", "openid email") - wantRedirectParams.Add("state", newState) - wantRedirectParams.Add("nonce", newNonce) - wantRedirectBaseURI := "http://idp-test-server/auth" unknownJWKPriv, _ := newKeyPair(t) jwkPriv, jwkPub := newKeyPair(t) @@ -212,7 +215,7 @@ func TestOIDCProcess(t *testing.T) { }, }, { - name: "request with an existing sessionID expired", + name: "request with an existing sessionID expired with no refresh token", req: withSessionHeader, storedTokenResponse: &oidc.TokenResponse{ IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday)), @@ -276,7 +279,7 @@ func TestOIDCProcess(t *testing.T) { name string req *envoy.CheckRequest storedAuthState *oidc.AuthorizationState - mockTokensResponse *tokensResponse + mockTokensResponse *idpTokensResponse mockStatusCode int responseVerify func(*testing.T, *envoy.CheckResponse) }{ @@ -284,7 +287,7 @@ func TestOIDCProcess(t *testing.T) { name: "successfully retrieve new tokens", req: callbackRequest, storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ + mockTokensResponse: &idpTokensResponse{ IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), AccessToken: "access-token", TokenType: "Bearer", @@ -381,42 +384,42 @@ func TestOIDCProcess(t *testing.T) { }, }, { - name: "idp server returns invalid JWT id-token", + name: "idp returned non-bearer token type", req: callbackRequest, storedAuthState: validAuthState, - mockStatusCode: http.StatusOK, - mockTokensResponse: &tokensResponse{ - IDToken: "not-a-jwt", + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), + TokenType: "not-bearer", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { - require.Equal(t, int32(codes.Internal), response.GetStatus().GetCode()) + require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) requireStandardResponseHeaders(t, response) requireStoredTokens(t, store, sessionID, false) }, }, { - name: "idp server returns JWT signed with unknown key", + name: "idp returned invalid expires_in for access token", req: callbackRequest, storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, unknownJWKPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), + TokenType: "Bearer", + ExpiresIn: -1, }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { - require.Equal(t, int32(codes.Internal), response.GetStatus().GetCode()) + require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) requireStandardResponseHeaders(t, response) requireStoredTokens(t, store, sessionID, false) }, }, { - name: "session nonce stored does idp returned nonce", - req: callbackRequest, - storedAuthState: &oidc.AuthorizationState{ - Nonce: "old-nonce", - State: newState, - RequestedURL: requestedAppURL, - }, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", "non-matching-nonce")), + name: "idp didn't return access token", + req: callbackRequest, + storedAuthState: validAuthState, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), + TokenType: "Bearer", + ExpiresIn: 3600, }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -425,24 +428,51 @@ func TestOIDCProcess(t *testing.T) { }, }, { - name: "idp returned empty audience", + name: "idp server returns invalid JWT id-token", req: callbackRequest, storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce)), + mockStatusCode: http.StatusOK, + mockTokensResponse: &idpTokensResponse{ + IDToken: "not-a-jwt", + TokenType: "Bearer", + ExpiresIn: 3600, + AccessToken: "access-token", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { - require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) + require.Equal(t, int32(codes.Internal), response.GetStatus().GetCode()) requireStandardResponseHeaders(t, response) requireStoredTokens(t, store, sessionID, false) }, }, { - name: "idp returned non-matching audience", + name: "idp server returns JWT signed with unknown key", req: callbackRequest, storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"non-matching-audience"})), + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, unknownJWKPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + TokenType: "Bearer", + ExpiresIn: 3600, + AccessToken: "access-token", + }, + responseVerify: func(t *testing.T, response *envoy.CheckResponse) { + require.Equal(t, int32(codes.Internal), response.GetStatus().GetCode()) + requireStandardResponseHeaders(t, response) + requireStoredTokens(t, store, sessionID, false) + }, + }, + { + name: "idp didn't return nonce", + req: callbackRequest, + storedAuthState: &oidc.AuthorizationState{ + Nonce: "old-nonce", + State: newState, + RequestedURL: requestedAppURL, + }, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder()), + TokenType: "Bearer", + ExpiresIn: 3600, + AccessToken: "access-token", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -451,12 +481,18 @@ func TestOIDCProcess(t *testing.T) { }, }, { - name: "idp returned non-bearer token type", - req: callbackRequest, - storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), - TokenType: "not-bearer", + name: "session nonce stored does not match idp returned nonce", + req: callbackRequest, + storedAuthState: &oidc.AuthorizationState{ + Nonce: "old-nonce", + State: newState, + RequestedURL: requestedAppURL, + }, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", "non-matching-nonce")), + TokenType: "Bearer", + ExpiresIn: 3600, + AccessToken: "access-token", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -465,13 +501,14 @@ func TestOIDCProcess(t *testing.T) { }, }, { - name: "idp returned invalid expires_in for access token", + name: "idp returned empty audience", req: callbackRequest, storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), - TokenType: "Bearer", - ExpiresIn: -1, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce)), + TokenType: "Bearer", + ExpiresIn: 3600, + AccessToken: "access-token", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -480,13 +517,14 @@ func TestOIDCProcess(t *testing.T) { }, }, { - name: "idp didn't return access token", + name: "idp returned non-matching audience", req: callbackRequest, storedAuthState: validAuthState, - mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), - TokenType: "Bearer", - ExpiresIn: 3600, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"non-matching-audience"})), + TokenType: "Bearer", + ExpiresIn: 3600, + AccessToken: "access-token", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -520,6 +558,265 @@ func TestOIDCProcess(t *testing.T) { tt.responseVerify(t, resp) }) } + + validIDToken := newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)) + validIDTokenWithoutNonce := newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"})) + + expiredTokenResponse := &oidc.TokenResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday).Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + RefreshToken: "refresh-token", + AccessToken: "access-token", + AccessTokenExpiresAt: yesterday, + } + + refreshTokensTests := []struct { + name string + req *envoy.CheckRequest + storedAuthState *oidc.AuthorizationState + storedTokenResponse *oidc.TokenResponse + mockTokensResponse *idpTokensResponse + mockStatusCode int + responseVerify func(*testing.T, *envoy.CheckResponse) + }{ + { + name: "IDP server returns empty body", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "IDP server returns an non-200 status", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockStatusCode: http.StatusInternalServerError, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "IDP server returns response with an invalid token_type", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: validIDToken, + AccessToken: "access-token", + TokenType: "invalid-token-type", + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "IDP server returns a response with an invalid expires_at", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: validIDToken, + AccessToken: "access-token", + TokenType: "Bearer", + ExpiresIn: -1, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "IDP server returns a response with no access token - succeeds using the stored access token", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: validIDToken, + TokenType: "Bearer", + ExpiresIn: 10, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.OK), resp.GetStatus().GetCode()) + require.NotNil(t, resp.GetOkResponse()) + requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, validIDToken, "access-token") + requireStoredTokens(t, store, sessionID, true) + requireStoredTokens(t, store, newSessionID, false) + }, + }, + { + name: "IDP server doesn't return an id-token - succeeds using the stored id-token", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + TokenType: "Bearer", + ExpiresIn: 10, + AccessToken: "access-token", + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.OK), resp.GetStatus().GetCode()) + require.NotNil(t, resp.GetOkResponse()) + requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, expiredTokenResponse.IDToken, "access-token") + requireStoredTokens(t, store, sessionID, true) + requireStoredTokens(t, store, newSessionID, false) + }, + }, + { + name: "IDP server returns an invalid JWT as id-token - succeeds using the stored id-token", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: "not-a-jwt", + TokenType: "Bearer", + ExpiresIn: 10, + AccessToken: "access-token", + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.OK), resp.GetStatus().GetCode()) + require.NotNil(t, resp.GetOkResponse()) + requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, expiredTokenResponse.IDToken, "access-token") + requireStoredTokens(t, store, sessionID, true) + requireStoredTokens(t, store, newSessionID, false) + }, + }, + { + name: "IDP server returns an id-token signed with unknown key", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, unknownJWKPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + AccessToken: "access-token", + TokenType: "Bearer", + ExpiresIn: 10, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "IDP server returns an id-token with non-matching nonce", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", "non-matching-nonce")), + AccessToken: "access-token", + TokenType: "Bearer", + ExpiresIn: 10, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "IDP server returns an id-token with no nonce claim - succeeds as it is not required", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: validIDTokenWithoutNonce, + AccessToken: "access-token", + TokenType: "Bearer", + ExpiresIn: 10, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.OK), resp.GetStatus().GetCode()) + require.NotNil(t, resp.GetOkResponse()) + requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, validIDTokenWithoutNonce, "access-token") + requireStoredTokens(t, store, sessionID, true) + requireStoredTokens(t, store, newSessionID, false) + }, + }, + { + name: "IDP server returns an id-token with non-matching audience", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"non-matching-audience"}).Claim("nonce", newNonce)), + AccessToken: "access-token", + TokenType: "Bearer", + ExpiresIn: 10, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + requireStoredState(t, store, sessionID, false) + }, + }, + { + name: "succeed", + req: withSessionHeader, + storedTokenResponse: expiredTokenResponse, + mockTokensResponse: &idpTokensResponse{ + IDToken: validIDToken, + AccessToken: "access-token", + TokenType: "Bearer", + ExpiresIn: 10, + }, + responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { + require.Equal(t, int32(codes.OK), resp.GetStatus().GetCode()) + require.NotNil(t, resp.GetOkResponse()) + requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, validIDToken, "access-token") + requireStoredTokens(t, store, sessionID, true) + requireStoredTokens(t, store, newSessionID, false) + }, + }, + } + + for _, tt := range refreshTokensTests { + t.Run("refresh tokens: "+tt.name, func(t *testing.T) { + idpServer.Start() + t.Cleanup(func() { + idpServer.Stop() + require.NoError(t, store.RemoveSession(ctx, sessionID)) + require.NoError(t, store.RemoveSession(ctx, newSessionID)) + }) + + idpServer.tokensResponse = tt.mockTokensResponse + idpServer.statusCode = tt.mockStatusCode + if tt.mockStatusCode <= 0 { + idpServer.statusCode = http.StatusOK + } + + if tt.storedAuthState == nil { + tt.storedAuthState = validAuthState + } + require.NoError(t, store.SetAuthorizationState(ctx, sessionID, tt.storedAuthState)) + if tt.storedTokenResponse != nil { + require.NoError(t, store.SetTokenResponse(ctx, sessionID, tt.storedTokenResponse)) + } + + resp := &envoy.CheckResponse{} + require.NoError(t, h.Process(ctx, tt.req, resp)) + tt.responseVerify(t, resp) + }) + } } func TestOIDCProcessWithFailingSessionStore(t *testing.T) { @@ -570,7 +867,7 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { idpServer := newServer() idpServer.statusCode = http.StatusOK - idpServer.tokensResponse = &tokensResponse{ + idpServer.tokensResponse = &idpTokensResponse{ IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), AccessToken: "access-token", TokenType: "Bearer", @@ -610,7 +907,51 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { require.NoError(t, h.Process(ctx, callbackRequest, resp)) requireSessionErrorResponse(t, resp) }) + } + + // The following subset of tests is testing the refresh tokens requests, so there's expected communication with the IDP server. + // The store is expected to fail in some way, so the handler should return an error response. + refreshTokensTests := []struct { + name string + storeCallsToError map[int]bool + wantRedirect bool + }{ + { + name: "refresh tokens - fails to get the authorization state", + storeCallsToError: map[int]bool{getAuthorizationState: true}, + wantRedirect: true, + }, + { + name: "refresh tokens - fails to set new token response", + storeCallsToError: map[int]bool{setTokenResponse: true}, + wantRedirect: false, + }, + } + + for _, tt := range refreshTokensTests { + t.Run(tt.name, func(t *testing.T) { + require.NoError(t, store.SetAuthorizationState(ctx, sessionID, validAuthState)) + require.NoError(t, store.SetTokenResponse(ctx, sessionID, &oidc.TokenResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday).Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + RefreshToken: "refresh-token", + AccessToken: "access-token", + AccessTokenExpiresAt: yesterday, + })) + store.errs = tt.storeCallsToError + t.Cleanup(func() { store.errs = nil }) + + resp := &envoy.CheckResponse{} + require.NoError(t, h.Process(ctx, withSessionHeader, resp)) + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + if tt.wantRedirect { + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + } else { + requireSessionErrorResponse(t, resp) + } + }) } } @@ -638,23 +979,42 @@ func TestOIDCProcessWithFailingJWKSProvider(t *testing.T) { require.NoError(t, store.RemoveSession(ctx, sessionID)) }) - idpServer.tokensResponse = &tokensResponse{ + idpServer.tokensResponse = &idpTokensResponse{ IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), AccessToken: "access-token", TokenType: "Bearer", } idpServer.statusCode = http.StatusOK - // Set the authorization state in the store, so it can be found by the handler + expiredTokenResponse := &oidc.TokenResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday).Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + RefreshToken: "refresh-token", + AccessToken: "access-token", + AccessTokenExpiresAt: yesterday, + } + require.NoError(t, store.SetAuthorizationState(ctx, sessionID, validAuthState)) - resp := &envoy.CheckResponse{} - err = h.Process(ctx, callbackRequest, resp) - require.NoError(t, err) + t.Run("callback request ", func(t *testing.T) { + resp := &envoy.CheckResponse{} + require.NoError(t, h.Process(ctx, callbackRequest, resp)) + require.Equal(t, int32(codes.Internal), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireStoredTokens(t, store, sessionID, false) + }) + + require.NoError(t, store.SetTokenResponse(ctx, sessionID, expiredTokenResponse)) - require.Equal(t, int32(codes.Internal), resp.GetStatus().GetCode()) - requireStandardResponseHeaders(t, resp) - requireStoredTokens(t, store, sessionID, false) + t.Run("refresh tokens - redirect to reauthenticate", func(t *testing.T) { + resp := &envoy.CheckResponse{} + require.NoError(t, h.Process(ctx, withSessionHeader, resp)) + + require.Equal(t, int32(codes.Unauthenticated), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireRedirectResponse(t, resp.GetDeniedResponse(), wantRedirectBaseURI, wantRedirectParams) + requireCookie(t, resp.GetDeniedResponse()) + requireStoredState(t, store, newSessionID, true) + }) } func TestMatchesCallbackPath(t *testing.T) { @@ -1072,7 +1432,7 @@ func requireStandardResponseHeaders(t *testing.T, resp *envoy.CheckResponse) { type idpServer struct { server *http.Server listener *bufconn.Listener - tokensResponse *tokensResponse + tokensResponse *idpTokensResponse statusCode int }