mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-10-30 19:58:36 +00:00 
			
		
		
		
	Pass the Requester down to the HttpTransactionCache. (#15200)
This commit is contained in:
		
							parent
							
								
									820f02b70b
								
							
						
					
					
						commit
						47bc84dd53
					
				
							
								
								
									
										1
									
								
								changelog.d/15200.misc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/15200.misc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1 @@ | ||||
| Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key. | ||||
| @ -12,7 +12,7 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from http import HTTPStatus | ||||
| from typing import TYPE_CHECKING, Awaitable, Optional, Tuple | ||||
| from typing import TYPE_CHECKING, Optional, Tuple | ||||
| 
 | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.errors import NotFoundError, SynapseError | ||||
| @ -23,10 +23,10 @@ from synapse.http.servlet import ( | ||||
|     parse_json_object_from_request, | ||||
| ) | ||||
| from synapse.http.site import SynapseRequest | ||||
| from synapse.rest.admin import assert_requester_is_admin | ||||
| from synapse.rest.admin._base import admin_patterns | ||||
| from synapse.logging.opentracing import set_tag | ||||
| from synapse.rest.admin._base import admin_patterns, assert_user_is_admin | ||||
| from synapse.rest.client.transactions import HttpTransactionCache | ||||
| from synapse.types import JsonDict, UserID | ||||
| from synapse.types import JsonDict, Requester, UserID | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
|     from synapse.server import HomeServer | ||||
| @ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet): | ||||
|             self.__class__.__name__, | ||||
|         ) | ||||
| 
 | ||||
|     async def on_POST( | ||||
|         self, request: SynapseRequest, txn_id: Optional[str] = None | ||||
|     async def _do( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         requester: Requester, | ||||
|         txn_id: Optional[str], | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         await assert_requester_is_admin(self.auth, request) | ||||
|         await assert_user_is_admin(self.auth, requester) | ||||
|         body = parse_json_object_from_request(request) | ||||
|         assert_params_in_dict(body, ("user_id", "content")) | ||||
|         event_type = body.get("type", EventTypes.Message) | ||||
| @ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet): | ||||
| 
 | ||||
|         return HTTPStatus.OK, {"event_id": event.event_id} | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_POST( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         return await self._do(request, requester, None) | ||||
| 
 | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_POST, request, txn_id | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         set_tag("txn_id", txn_id) | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, requester, self._do, request, requester, txn_id | ||||
|         ) | ||||
|  | ||||
| @ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process | ||||
| from synapse.rest.client._base import client_patterns | ||||
| from synapse.rest.client.transactions import HttpTransactionCache | ||||
| from synapse.streams.config import PaginationConfig | ||||
| from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID | ||||
| from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID | ||||
| from synapse.types.state import StateFilter | ||||
| from synapse.util import json_decoder | ||||
| from synapse.util.cancellation import cancellable | ||||
| @ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet): | ||||
|         PATTERNS = "/createRoom" | ||||
|         register_txn_path(self, PATTERNS, http_server) | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         set_tag("txn_id", txn_id) | ||||
|         return self.txns.fetch_or_execute_request(request, self.on_POST, request) | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, requester, self._do, request, requester | ||||
|         ) | ||||
| 
 | ||||
|     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         return await self._do(request, requester) | ||||
| 
 | ||||
|     async def _do( | ||||
|         self, request: SynapseRequest, requester: Requester | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         room_id, _, _ = await self._room_creation_handler.create_room( | ||||
|             requester, self.get_room_config(request) | ||||
|         ) | ||||
| @ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet): | ||||
| 
 | ||||
| 
 | ||||
| # TODO: Needs unit testing for generic events | ||||
| class RoomStateEventRestServlet(TransactionRestServlet): | ||||
| class RoomStateEventRestServlet(RestServlet): | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__(hs) | ||||
|         super().__init__() | ||||
|         self.event_creation_handler = hs.get_event_creation_handler() | ||||
|         self.room_member_handler = hs.get_room_member_handler() | ||||
|         self.message_handler = hs.get_message_handler() | ||||
| @ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet): | ||||
|     def register(self, http_server: HttpServer) -> None: | ||||
|         # /rooms/$roomid/send/$event_type[/$txn_id] | ||||
|         PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" | ||||
|         register_txn_path(self, PATTERNS, http_server, with_get=True) | ||||
|         register_txn_path(self, PATTERNS, http_server) | ||||
| 
 | ||||
|     async def on_POST( | ||||
|     async def _do( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         requester: Requester, | ||||
|         room_id: str, | ||||
|         event_type: str, | ||||
|         txn_id: Optional[str] = None, | ||||
|         txn_id: Optional[str], | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         event_dict: JsonDict = { | ||||
| @ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet): | ||||
|         set_tag("event_id", event_id) | ||||
|         return 200, {"event_id": event_id} | ||||
| 
 | ||||
|     def on_GET( | ||||
|         self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str | ||||
|     ) -> Tuple[int, str]: | ||||
|         return 200, "Not implemented" | ||||
|     async def on_POST( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         room_id: str, | ||||
|         event_type: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         return await self._do(request, requester, room_id, event_type, None) | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         set_tag("txn_id", txn_id) | ||||
| 
 | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_POST, request, room_id, event_type, txn_id | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, | ||||
|             requester, | ||||
|             self._do, | ||||
|             request, | ||||
|             requester, | ||||
|             room_id, | ||||
|             event_type, | ||||
|             txn_id, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): | ||||
|         PATTERNS = "/join/(?P<room_identifier>[^/]*)" | ||||
|         register_txn_path(self, PATTERNS, http_server) | ||||
| 
 | ||||
|     async def on_POST( | ||||
|     async def _do( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         requester: Requester, | ||||
|         room_identifier: str, | ||||
|         txn_id: Optional[str] = None, | ||||
|         txn_id: Optional[str], | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         content = parse_json_object_from_request(request, allow_empty_body=True) | ||||
| 
 | ||||
|         # twisted.web.server.Request.args is incorrectly defined as Optional[Any] | ||||
| @ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): | ||||
| 
 | ||||
|         return 200, {"room_id": room_id} | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_POST( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         room_identifier: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         return await self._do(request, requester, room_identifier, None) | ||||
| 
 | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, room_identifier: str, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         set_tag("txn_id", txn_id) | ||||
| 
 | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_POST, request, room_identifier, txn_id | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, requester, self._do, request, requester, room_identifier, txn_id | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| # TODO: Needs unit testing | ||||
| class PublicRoomListRestServlet(TransactionRestServlet): | ||||
| class PublicRoomListRestServlet(RestServlet): | ||||
|     PATTERNS = client_patterns("/publicRooms$", v1=True) | ||||
| 
 | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         super().__init__(hs) | ||||
|         super().__init__() | ||||
|         self.hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
| @ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet): | ||||
|         PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" | ||||
|         register_txn_path(self, PATTERNS, http_server) | ||||
| 
 | ||||
|     async def on_POST( | ||||
|         self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=False) | ||||
| 
 | ||||
|     async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]: | ||||
|         await self.room_member_handler.forget(user=requester.user, room_id=room_id) | ||||
| 
 | ||||
|         return 200, {} | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_POST( | ||||
|         self, request: SynapseRequest, room_id: str | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=False) | ||||
|         return await self._do(requester, room_id) | ||||
| 
 | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, room_id: str, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=False) | ||||
|         set_tag("txn_id", txn_id) | ||||
| 
 | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_POST, request, room_id, txn_id | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, requester, self._do, requester, room_id | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet): | ||||
|         ) | ||||
|         register_txn_path(self, PATTERNS, http_server) | ||||
| 
 | ||||
|     async def on_POST( | ||||
|     async def _do( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         requester: Requester, | ||||
|         room_id: str, | ||||
|         membership_action: str, | ||||
|         txn_id: Optional[str] = None, | ||||
|         txn_id: Optional[str], | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
| 
 | ||||
|         if requester.is_guest and membership_action not in { | ||||
|             Membership.JOIN, | ||||
|             Membership.LEAVE, | ||||
| @ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet): | ||||
| 
 | ||||
|         return 200, return_value | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_POST( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         room_id: str, | ||||
|         membership_action: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         return await self._do(request, requester, room_id, membership_action, None) | ||||
| 
 | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         set_tag("txn_id", txn_id) | ||||
| 
 | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_POST, request, room_id, membership_action, txn_id | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, | ||||
|             requester, | ||||
|             self._do, | ||||
|             request, | ||||
|             requester, | ||||
|             room_id, | ||||
|             membership_action, | ||||
|             txn_id, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet): | ||||
|         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" | ||||
|         register_txn_path(self, PATTERNS, http_server) | ||||
| 
 | ||||
|     async def on_POST( | ||||
|     async def _do( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         requester: Requester, | ||||
|         room_id: str, | ||||
|         event_id: str, | ||||
|         txn_id: Optional[str] = None, | ||||
|         txn_id: Optional[str], | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         try: | ||||
| @ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet): | ||||
|         set_tag("event_id", event_id) | ||||
|         return 200, {"event_id": event_id} | ||||
| 
 | ||||
|     def on_PUT( | ||||
|     async def on_POST( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         room_id: str, | ||||
|         event_id: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         return await self._do(request, requester, room_id, event_id, None) | ||||
| 
 | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request) | ||||
|         set_tag("txn_id", txn_id) | ||||
| 
 | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self.on_POST, request, room_id, event_id, txn_id | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, requester, self._do, request, requester, room_id, event_id, txn_id | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @ -1224,7 +1280,6 @@ def register_txn_path( | ||||
|     servlet: RestServlet, | ||||
|     regex_string: str, | ||||
|     http_server: HttpServer, | ||||
|     with_get: bool = False, | ||||
| ) -> None: | ||||
|     """Registers a transaction-based path. | ||||
| 
 | ||||
| @ -1236,7 +1291,6 @@ def register_txn_path( | ||||
|         regex_string: The regex string to register. Must NOT have a | ||||
|             trailing $ as this string will be appended to. | ||||
|         http_server: The http_server to register paths with. | ||||
|         with_get: True to also register respective GET paths for the PUTs. | ||||
|     """ | ||||
|     on_POST = getattr(servlet, "on_POST", None) | ||||
|     on_PUT = getattr(servlet, "on_PUT", None) | ||||
| @ -1254,18 +1308,6 @@ def register_txn_path( | ||||
|         on_PUT, | ||||
|         servlet.__class__.__name__, | ||||
|     ) | ||||
|     on_GET = getattr(servlet, "on_GET", None) | ||||
|     if with_get: | ||||
|         if on_GET is None: | ||||
|             raise RuntimeError( | ||||
|                 "register_txn_path called with with_get = True, but no on_GET method exists" | ||||
|             ) | ||||
|         http_server.register_paths( | ||||
|             "GET", | ||||
|             client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), | ||||
|             on_GET, | ||||
|             servlet.__class__.__name__, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class TimestampLookupRestServlet(RestServlet): | ||||
|  | ||||
| @ -13,7 +13,7 @@ | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| from typing import TYPE_CHECKING, Awaitable, Tuple | ||||
| from typing import TYPE_CHECKING, Tuple | ||||
| 
 | ||||
| from synapse.http import servlet | ||||
| from synapse.http.server import HttpServer | ||||
| @ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_r | ||||
| from synapse.http.site import SynapseRequest | ||||
| from synapse.logging.opentracing import set_tag | ||||
| from synapse.rest.client.transactions import HttpTransactionCache | ||||
| from synapse.types import JsonDict | ||||
| from synapse.types import JsonDict, Requester | ||||
| 
 | ||||
| from ._base import client_patterns | ||||
| 
 | ||||
| @ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet): | ||||
|         self.txns = HttpTransactionCache(hs) | ||||
|         self.device_message_handler = hs.get_device_message_handler() | ||||
| 
 | ||||
|     def on_PUT( | ||||
|         self, request: SynapseRequest, message_type: str, txn_id: str | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|         set_tag("txn_id", txn_id) | ||||
|         return self.txns.fetch_or_execute_request( | ||||
|             request, self._put, request, message_type, txn_id | ||||
|         ) | ||||
| 
 | ||||
|     async def _put( | ||||
|     async def on_PUT( | ||||
|         self, request: SynapseRequest, message_type: str, txn_id: str | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         requester = await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         set_tag("txn_id", txn_id) | ||||
|         return await self.txns.fetch_or_execute_request( | ||||
|             request, | ||||
|             requester, | ||||
|             self._put, | ||||
|             request, | ||||
|             requester, | ||||
|             message_type, | ||||
|         ) | ||||
| 
 | ||||
|     async def _put( | ||||
|         self, | ||||
|         request: SynapseRequest, | ||||
|         requester: Requester, | ||||
|         message_type: str, | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         content = parse_json_object_from_request(request) | ||||
|         assert_params_in_dict(content, ("messages",)) | ||||
| 
 | ||||
|  | ||||
| @ -15,16 +15,16 @@ | ||||
| """This module contains logic for storing HTTP PUT transactions. This is used | ||||
| to ensure idempotency when performing PUTs using the REST API.""" | ||||
| import logging | ||||
| from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple | ||||
| from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple | ||||
| 
 | ||||
| from typing_extensions import ParamSpec | ||||
| 
 | ||||
| from twisted.internet.defer import Deferred | ||||
| from twisted.python.failure import Failure | ||||
| from twisted.web.server import Request | ||||
| from twisted.web.iweb import IRequest | ||||
| 
 | ||||
| from synapse.logging.context import make_deferred_yieldable, run_in_background | ||||
| from synapse.types import JsonDict | ||||
| from synapse.types import JsonDict, Requester | ||||
| from synapse.util.async_helpers import ObservableDeferred | ||||
| 
 | ||||
| if TYPE_CHECKING: | ||||
| @ -41,53 +41,47 @@ P = ParamSpec("P") | ||||
| class HttpTransactionCache: | ||||
|     def __init__(self, hs: "HomeServer"): | ||||
|         self.hs = hs | ||||
|         self.auth = self.hs.get_auth() | ||||
|         self.clock = self.hs.get_clock() | ||||
|         # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) | ||||
|         self.transactions: Dict[ | ||||
|             str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] | ||||
|             Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] | ||||
|         ] = {} | ||||
|         # Try to clean entries every 30 mins. This means entries will exist | ||||
|         # for at *LEAST* 30 mins, and at *MOST* 60 mins. | ||||
|         self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) | ||||
| 
 | ||||
|     def _get_transaction_key(self, request: Request) -> str: | ||||
|     def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable: | ||||
|         """A helper function which returns a transaction key that can be used | ||||
|         with TransactionCache for idempotent requests. | ||||
| 
 | ||||
|         Idempotency is based on the returned key being the same for separate | ||||
|         requests to the same endpoint. The key is formed from the HTTP request | ||||
|         path and the access_token for the requesting user. | ||||
|         path and attributes from the requester: the access_token_id for regular users, | ||||
|         the user ID for guest users, and the appservice ID for appservice users. | ||||
| 
 | ||||
|         Args: | ||||
|             request: The incoming request. Must contain an access_token. | ||||
|             request: The incoming request. | ||||
|             requester: The requester doing the request. | ||||
|         Returns: | ||||
|             A transaction key | ||||
|         """ | ||||
|         assert request.path is not None | ||||
|         token = self.auth.get_access_token_from_request(request) | ||||
|         return request.path.decode("utf8") + "/" + token | ||||
|         path: str = request.path.decode("utf8") | ||||
|         if requester.is_guest: | ||||
|             assert requester.user is not None, "Guest requester must have a user ID set" | ||||
|             return (path, "guest", requester.user) | ||||
|         elif requester.app_service is not None: | ||||
|             return (path, "appservice", requester.app_service.id) | ||||
|         else: | ||||
|             assert ( | ||||
|                 requester.access_token_id is not None | ||||
|             ), "Requester must have an access_token_id" | ||||
|             return (path, "user", requester.access_token_id) | ||||
| 
 | ||||
|     def fetch_or_execute_request( | ||||
|         self, | ||||
|         request: Request, | ||||
|         fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], | ||||
|         *args: P.args, | ||||
|         **kwargs: P.kwargs, | ||||
|     ) -> Awaitable[Tuple[int, JsonDict]]: | ||||
|         """A helper function for fetch_or_execute which extracts | ||||
|         a transaction key from the given request. | ||||
| 
 | ||||
|         See: | ||||
|             fetch_or_execute | ||||
|         """ | ||||
|         return self.fetch_or_execute( | ||||
|             self._get_transaction_key(request), fn, *args, **kwargs | ||||
|         ) | ||||
| 
 | ||||
|     def fetch_or_execute( | ||||
|         self, | ||||
|         txn_key: str, | ||||
|         request: IRequest, | ||||
|         requester: Requester, | ||||
|         fn: Callable[P, Awaitable[Tuple[int, JsonDict]]], | ||||
|         *args: P.args, | ||||
|         **kwargs: P.kwargs, | ||||
| @ -96,14 +90,15 @@ class HttpTransactionCache: | ||||
|         to produce a response for this transaction. | ||||
| 
 | ||||
|         Args: | ||||
|             txn_key: A key to ensure idempotency should fetch_or_execute be | ||||
|                 called again at a later point in time. | ||||
|             request: | ||||
|             requester: | ||||
|             fn: A function which returns a tuple of (response_code, response_dict). | ||||
|             *args: Arguments to pass to fn. | ||||
|             **kwargs: Keyword arguments to pass to fn. | ||||
|         Returns: | ||||
|             Deferred which resolves to a tuple of (response_code, response_dict). | ||||
|         """ | ||||
|         txn_key = self._get_transaction_key(request, requester) | ||||
|         if txn_key in self.transactions: | ||||
|             observable = self.transactions[txn_key][0] | ||||
|         else: | ||||
|  | ||||
| @ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | ||||
|         self.cache = HttpTransactionCache(self.hs) | ||||
| 
 | ||||
|         self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"}) | ||||
|         self.mock_key = "foo" | ||||
| 
 | ||||
|         # Here we make sure that we're setting all the fields that HttpTransactionCache | ||||
|         # uses to build the transaction key. | ||||
|         self.mock_request = Mock() | ||||
|         self.mock_request.path = b"/foo/bar" | ||||
|         self.mock_requester = Mock() | ||||
|         self.mock_requester.app_service = None | ||||
|         self.mock_requester.is_guest = False | ||||
|         self.mock_requester.access_token_id = 1234 | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_executes_given_function( | ||||
|         self, | ||||
|     ) -> Generator["defer.Deferred[Any]", object, None]: | ||||
|         cb = Mock(return_value=make_awaitable(self.mock_http_response)) | ||||
|         res = yield self.cache.fetch_or_execute( | ||||
|             self.mock_key, cb, "some_arg", keyword="arg" | ||||
|         res = yield self.cache.fetch_or_execute_request( | ||||
|             self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" | ||||
|         ) | ||||
|         cb.assert_called_once_with("some_arg", keyword="arg") | ||||
|         self.assertEqual(res, self.mock_http_response) | ||||
| @ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | ||||
|     ) -> Generator["defer.Deferred[Any]", object, None]: | ||||
|         cb = Mock(return_value=make_awaitable(self.mock_http_response)) | ||||
|         for i in range(3):  # invoke multiple times | ||||
|             res = yield self.cache.fetch_or_execute( | ||||
|                 self.mock_key, cb, "some_arg", keyword="arg", changing_args=i | ||||
|             res = yield self.cache.fetch_or_execute_request( | ||||
|                 self.mock_request, | ||||
|                 self.mock_requester, | ||||
|                 cb, | ||||
|                 "some_arg", | ||||
|                 keyword="arg", | ||||
|                 changing_args=i, | ||||
|             ) | ||||
|             self.assertEqual(res, self.mock_http_response) | ||||
|         # expect only a single call to do the work | ||||
| @ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | ||||
|         @defer.inlineCallbacks | ||||
|         def test() -> Generator["defer.Deferred[Any]", object, None]: | ||||
|             with LoggingContext("c") as c1: | ||||
|                 res = yield self.cache.fetch_or_execute(self.mock_key, cb) | ||||
|                 res = yield self.cache.fetch_or_execute_request( | ||||
|                     self.mock_request, self.mock_requester, cb | ||||
|                 ) | ||||
|                 self.assertIs(current_context(), c1) | ||||
|                 self.assertEqual(res, (1, {})) | ||||
| 
 | ||||
| @ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | ||||
| 
 | ||||
|         with LoggingContext("test") as test_context: | ||||
|             try: | ||||
|                 yield self.cache.fetch_or_execute(self.mock_key, cb) | ||||
|                 yield self.cache.fetch_or_execute_request( | ||||
|                     self.mock_request, self.mock_requester, cb | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 self.assertEqual(e.args[0], "boo") | ||||
|             self.assertIs(current_context(), test_context) | ||||
| 
 | ||||
|             res = yield self.cache.fetch_or_execute(self.mock_key, cb) | ||||
|             res = yield self.cache.fetch_or_execute_request( | ||||
|                 self.mock_request, self.mock_requester, cb | ||||
|             ) | ||||
|             self.assertEqual(res, self.mock_http_response) | ||||
|             self.assertIs(current_context(), test_context) | ||||
| 
 | ||||
| @ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | ||||
| 
 | ||||
|         with LoggingContext("test") as test_context: | ||||
|             try: | ||||
|                 yield self.cache.fetch_or_execute(self.mock_key, cb) | ||||
|                 yield self.cache.fetch_or_execute_request( | ||||
|                     self.mock_request, self.mock_requester, cb | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 self.assertEqual(e.args[0], "boo") | ||||
|             self.assertIs(current_context(), test_context) | ||||
| 
 | ||||
|             res = yield self.cache.fetch_or_execute(self.mock_key, cb) | ||||
|             res = yield self.cache.fetch_or_execute_request( | ||||
|                 self.mock_request, self.mock_requester, cb | ||||
|             ) | ||||
|             self.assertEqual(res, self.mock_http_response) | ||||
|             self.assertIs(current_context(), test_context) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: | ||||
|         cb = Mock(return_value=make_awaitable(self.mock_http_response)) | ||||
|         yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") | ||||
|         yield self.cache.fetch_or_execute_request( | ||||
|             self.mock_request, self.mock_requester, cb, "an arg" | ||||
|         ) | ||||
|         # should NOT have cleaned up yet | ||||
|         self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) | ||||
| 
 | ||||
|         yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") | ||||
|         yield self.cache.fetch_or_execute_request( | ||||
|             self.mock_request, self.mock_requester, cb, "an arg" | ||||
|         ) | ||||
|         # still using cache | ||||
|         cb.assert_called_once_with("an arg") | ||||
| 
 | ||||
|         self.clock.advance_time_msec(CLEANUP_PERIOD_MS) | ||||
| 
 | ||||
|         yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") | ||||
|         yield self.cache.fetch_or_execute_request( | ||||
|             self.mock_request, self.mock_requester, cb, "an arg" | ||||
|         ) | ||||
|         # no longer using cache | ||||
|         self.assertEqual(cb.call_count, 2) | ||||
|         self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")]) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user