Use ObservableDeferreds instead of Deferreds as they behave as intended

This commit is contained in:
Kegan Dougal 2016-11-11 14:54:10 +00:00
parent c7daf3136c
commit 42c43cfafd
2 changed files with 31 additions and 30 deletions

View File

@ -22,6 +22,7 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.util.async import ObservableDeferred
from synapse.events.utils import serialize_event, format_event_for_client_v2 from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer parse_json_object_from_request, parse_string, parse_integer
@ -57,14 +58,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
try: try:
res_deferred = self.txns.get_client_transaction(request, txn_id) res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
except KeyError: except KeyError:
pass pass
res_deferred = self.on_POST(request) res_deferred = ObservableDeferred(self.on_POST(request))
self.txns.store_client_transaction(request, txn_id, res_deferred) self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred response = yield res_deferred.observe()
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -218,14 +219,14 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(self, request, room_id, event_type, txn_id):
try: try:
res_deferred = self.txns.get_client_transaction(request, txn_id) res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
except KeyError: except KeyError:
pass pass
res_deferred = self.on_POST(request, room_id, event_type, txn_id) res_deferred = ObservableDeferred(self.on_POST(request, room_id, event_type, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred) self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred response = yield res_deferred.observe()
defer.returnValue(response) defer.returnValue(response)
@ -287,14 +288,14 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
try: try:
res_deferred = self.txns.get_client_transaction(request, txn_id) res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
except KeyError: except KeyError:
pass pass
res_deferred = self.on_POST(request, room_identifier, txn_id) res_deferred = ObservableDeferred(self.on_POST(request, room_identifier, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred) self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred response = yield res_deferred.observe()
defer.returnValue(response) defer.returnValue(response)
@ -541,14 +542,14 @@ class RoomForgetRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, txn_id): def on_PUT(self, request, room_id, txn_id):
try: try:
res_deferred = self.txns.get_client_transaction(request, txn_id) res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
except KeyError: except KeyError:
pass pass
res_deferred = self.on_POST(request, room_id, txn_id) res_deferred = ObservableDeferred(self.on_POST(request, room_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred) self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred response = yield res_deferred.observe()
defer.returnValue(response) defer.returnValue(response)
@ -624,15 +625,15 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: try:
res_deferred = self.txns.get_client_transaction(request, txn_id) res_deferred = ObservableDeferred(self.txns.get_client_transaction(request, txn_id))
res = yield res_deferred res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
except KeyError: except KeyError:
pass pass
res_deferred = self.on_POST(request, room_id, membership_action, txn_id) res_deferred = ObservableDeffself.on_POST(request, room_id, membership_action, txn_id)
self.txns.store_client_transaction(request, txn_id, res_deferred) self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred response = yield res_deferred.observe()
defer.returnValue(response) defer.returnValue(response)
@ -669,14 +670,14 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(self, request, room_id, event_id, txn_id):
try: try:
res_deferred = self.txns.get_client_transaction(request, txn_id) res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
except KeyError: except KeyError:
pass pass
res_deferred = self.on_POST(request, room_id, event_id, txn_id) res_deferred = ObservableDeferred(self.on_POST(request, room_id, event_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred) self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred response = yield res_deferred.observe()
defer.returnValue(response) defer.returnValue(response)

View File

@ -25,32 +25,32 @@ logger = logging.getLogger(__name__)
class HttpTransactionCache(object): class HttpTransactionCache(object):
def __init__(self): def __init__(self):
# { key : (txn_id, response_deferred) } # { key : (txn_id, res_observ_defer) }
self.transactions = {} self.transactions = {}
def _get_response(self, key, txn_id): def _get_response(self, key, txn_id):
try: try:
(last_txn_id, response_deferred) = self.transactions[key] (last_txn_id, res_observ_defer) = self.transactions[key]
if txn_id == last_txn_id: if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", txn_id) logger.info("get_response: Returning a response for %s", txn_id)
return response_deferred return res_observ_defer
except KeyError: except KeyError:
pass pass
return None return None
def _store_response(self, key, txn_id, response_deferred): def _store_response(self, key, txn_id, res_observ_defer):
self.transactions[key] = (txn_id, response_deferred) self.transactions[key] = (txn_id, res_observ_defer)
def store_client_transaction(self, request, txn_id, response_deferred): def store_client_transaction(self, request, txn_id, res_observ_defer):
"""Stores the request/Promise<response> pair of an HTTP transaction. """Stores the request/Promise<response> pair of an HTTP transaction.
Args: Args:
request (twisted.web.http.Request): The twisted HTTP request. This request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment. request must have the transaction ID as the last path segment.
response_deferred (Promise<tuple>): A tuple of (response code, response dict) res_observ_defer (Promise<tuple>): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request. txn_id (str): The transaction ID for this request.
""" """
self._store_response(self._get_key(request), txn_id, response_deferred) self._store_response(self._get_key(request), txn_id, res_observ_defer)
def get_client_transaction(self, request, txn_id): def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one. """Retrieves a stored response if there was one.
@ -64,10 +64,10 @@ class HttpTransactionCache(object):
Raises: Raises:
KeyError if the transaction was not found. KeyError if the transaction was not found.
""" """
response_deferred = self._get_response(self._get_key(request), txn_id) res_observ_defer = self._get_response(self._get_key(request), txn_id)
if response_deferred is None: if res_observ_defer is None:
raise KeyError("Transaction not found.") raise KeyError("Transaction not found.")
return response_deferred return res_observ_defer
def _get_key(self, request): def _get_key(self, request):
token = get_access_token_from_request(request) token = get_access_token_from_request(request)