mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-25 09:18:14 +00:00 
			
		
		
		
	Cache token introspection response from OIDC provider (#16117)
This commit is contained in:
		
							parent
							
								
									eb0dbab15b
								
							
						
					
					
						commit
						54a51ff6c1
					
				
							
								
								
									
										1
									
								
								changelog.d/16117.misc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/16117.misc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | |||||||
|  | Cache token introspection response from OIDC provider. | ||||||
| @ -39,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable | |||||||
| from synapse.types import Requester, UserID, create_requester | from synapse.types import Requester, UserID, create_requester | ||||||
| from synapse.util import json_decoder | from synapse.util import json_decoder | ||||||
| from synapse.util.caches.cached_call import RetryOnExceptionCachedCall | from synapse.util.caches.cached_call import RetryOnExceptionCachedCall | ||||||
|  | from synapse.util.caches.expiringcache import ExpiringCache | ||||||
| 
 | 
 | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from synapse.server import HomeServer |     from synapse.server import HomeServer | ||||||
| @ -106,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth): | |||||||
| 
 | 
 | ||||||
|         self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) |         self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) | ||||||
| 
 | 
 | ||||||
|  |         self._clock = hs.get_clock() | ||||||
|  |         self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache( | ||||||
|  |             cache_name="introspection_token_cache", | ||||||
|  |             clock=self._clock, | ||||||
|  |             max_len=10000, | ||||||
|  |             expiry_ms=5 * 60 * 1000, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|         if isinstance(auth_method, PrivateKeyJWTWithKid): |         if isinstance(auth_method, PrivateKeyJWTWithKid): | ||||||
|             # Use the JWK as the client secret when using the private_key_jwt method |             # Use the JWK as the client secret when using the private_key_jwt method | ||||||
|             assert self._config.jwk, "No JWK provided" |             assert self._config.jwk, "No JWK provided" | ||||||
| @ -144,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth): | |||||||
|         Returns: |         Returns: | ||||||
|             The introspection response |             The introspection response | ||||||
|         """ |         """ | ||||||
|  |         # check the cache before doing a request | ||||||
|  |         introspection_token = self._token_cache.get(token, None) | ||||||
|  | 
 | ||||||
|  |         if introspection_token: | ||||||
|  |             # check the expiration field of the token (if it exists) | ||||||
|  |             exp = introspection_token.get("exp", None) | ||||||
|  |             if exp: | ||||||
|  |                 time_now = self._clock.time() | ||||||
|  |                 expired = time_now > exp | ||||||
|  |                 if not expired: | ||||||
|  |                     return introspection_token | ||||||
|  |             else: | ||||||
|  |                 return introspection_token | ||||||
|  | 
 | ||||||
|         metadata = await self._issuer_metadata.get() |         metadata = await self._issuer_metadata.get() | ||||||
|         introspection_endpoint = metadata.get("introspection_endpoint") |         introspection_endpoint = metadata.get("introspection_endpoint") | ||||||
|         raw_headers: Dict[str, str] = { |         raw_headers: Dict[str, str] = { | ||||||
| @ -157,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth): | |||||||
| 
 | 
 | ||||||
|         # Fill the body/headers with credentials |         # Fill the body/headers with credentials | ||||||
|         uri, raw_headers, body = self._client_auth.prepare( |         uri, raw_headers, body = self._client_auth.prepare( | ||||||
|             method="POST", uri=introspection_endpoint, headers=raw_headers, body=body |             method="POST", | ||||||
|  |             uri=introspection_endpoint, | ||||||
|  |             headers=raw_headers, | ||||||
|  |             body=body, | ||||||
|         ) |         ) | ||||||
|         headers = Headers({k: [v] for (k, v) in raw_headers.items()}) |         headers = Headers({k: [v] for (k, v) in raw_headers.items()}) | ||||||
| 
 | 
 | ||||||
| @ -187,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth): | |||||||
|                 "The introspection endpoint returned an invalid JSON response." |                 "The introspection endpoint returned an invalid JSON response." | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         return IntrospectionToken(**resp) |         expiration = resp.get("exp", None) | ||||||
|  |         if expiration: | ||||||
|  |             if self._clock.time() > expiration: | ||||||
|  |                 raise InvalidClientTokenError("Token is expired.") | ||||||
|  | 
 | ||||||
|  |         introspection_token = IntrospectionToken(**resp) | ||||||
|  | 
 | ||||||
|  |         # add token to cache | ||||||
|  |         self._token_cache[token] = introspection_token | ||||||
|  | 
 | ||||||
|  |         return introspection_token | ||||||
| 
 | 
 | ||||||
|     async def is_server_admin(self, requester: Requester) -> bool: |     async def is_server_admin(self, requester: Requester) -> bool: | ||||||
|         return "urn:synapse:admin:*" in requester.scope |         return "urn:synapse:admin:*" in requester.scope | ||||||
|  | |||||||
| @ -491,6 +491,68 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||||
|         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) |         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | ||||||
|         self.assertEqual(error.value.code, 503) |         self.assertEqual(error.value.code, 503) | ||||||
| 
 | 
 | ||||||
|  |     def test_introspection_token_cache(self) -> None: | ||||||
|  |         access_token = "open_sesame" | ||||||
|  |         self.http_client.request = simple_async_mock( | ||||||
|  |             return_value=FakeResponse.json( | ||||||
|  |                 code=200, | ||||||
|  |                 payload={"active": "true", "scope": "guest", "jti": access_token}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # first call should cache response | ||||||
|  |         # Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code | ||||||
|  |         # for regular auth code via the config | ||||||
|  |         self.get_success( | ||||||
|  |             self.auth._introspect_token(access_token)  # type: ignore[attr-defined] | ||||||
|  |         ) | ||||||
|  |         introspection_token = self.auth._token_cache.get(access_token)  # type: ignore[attr-defined] | ||||||
|  |         self.assertEqual(introspection_token["jti"], access_token) | ||||||
|  |         # there's been one http request | ||||||
|  |         self.http_client.request.assert_called_once() | ||||||
|  | 
 | ||||||
|  |         # second call should pull from cache, there should still be only one http request | ||||||
|  |         token = self.get_success(self.auth._introspect_token(access_token))  # type: ignore[attr-defined] | ||||||
|  |         self.http_client.request.assert_called_once() | ||||||
|  |         self.assertEqual(token["jti"], access_token) | ||||||
|  | 
 | ||||||
|  |         # advance past five minutes and check that cache expired - there should be more than one http call now | ||||||
|  |         self.reactor.advance(360) | ||||||
|  |         token_2 = self.get_success(self.auth._introspect_token(access_token))  # type: ignore[attr-defined] | ||||||
|  |         self.assertEqual(self.http_client.request.call_count, 2) | ||||||
|  |         self.assertEqual(token_2["jti"], access_token) | ||||||
|  | 
 | ||||||
|  |         # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a | ||||||
|  |         # token with a soon-to-expire `exp` field to the cache | ||||||
|  |         self.http_client.request = simple_async_mock( | ||||||
|  |             return_value=FakeResponse.json( | ||||||
|  |                 code=200, | ||||||
|  |                 payload={ | ||||||
|  |                     "active": "true", | ||||||
|  |                     "scope": "guest", | ||||||
|  |                     "jti": "stale", | ||||||
|  |                     "exp": self.clock.time() + 100, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.get_success( | ||||||
|  |             self.auth._introspect_token("stale")  # type: ignore[attr-defined] | ||||||
|  |         ) | ||||||
|  |         introspection_token = self.auth._token_cache.get("stale")  # type: ignore[attr-defined] | ||||||
|  |         self.assertEqual(introspection_token["jti"], "stale") | ||||||
|  |         self.assertEqual(self.http_client.request.call_count, 1) | ||||||
|  | 
 | ||||||
|  |         # advance the reactor past the token expiry but less than the cache expiry | ||||||
|  |         self.reactor.advance(120) | ||||||
|  |         self.assertEqual(self.auth._token_cache.get("stale"), introspection_token)  # type: ignore[attr-defined] | ||||||
|  | 
 | ||||||
|  |         # check that the next call causes another http request (which will fail because the token is technically expired | ||||||
|  |         # but the important thing is we discard the token from the cache and try the network) | ||||||
|  |         self.get_failure( | ||||||
|  |             self.auth._introspect_token("stale"), InvalidClientTokenError  # type: ignore[attr-defined] | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(self.http_client.request.call_count, 2) | ||||||
|  | 
 | ||||||
|     def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: |     def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: | ||||||
|         # We only generate a master key to simplify the test. |         # We only generate a master key to simplify the test. | ||||||
|         master_signing_key = generate_signing_key(device_id) |         master_signing_key = generate_signing_key(device_id) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user