Add a timeout parameter for end2end key queries.

Add a timeout parameter for controlling how long synapse will wait
for responses from remote servers. For servers that fail include how
they failed to make it easier to debug.

Fetch keys from different servers in parallel rather than in series.

Set the default timeout to 10s.
This commit is contained in:
Mark Haines 2016-09-12 18:17:09 +01:00
parent aa7b890cfe
commit 949c2c5435
5 changed files with 114 additions and 54 deletions

View File

@ -176,7 +176,7 @@ class FederationClient(FederationBase):
) )
@log_function @log_function
def query_client_keys(self, destination, content): def query_client_keys(self, destination, content, timeout):
"""Query device keys for a device hosted on a remote server. """Query device keys for a device hosted on a remote server.
Args: Args:
@ -188,10 +188,12 @@ class FederationClient(FederationBase):
response response
""" """
sent_queries_counter.inc("client_device_keys") sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content) return self.transport_layer.query_client_keys(
destination, content, timeout
)
@log_function @log_function
def claim_client_keys(self, destination, content): def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
Args: Args:
@ -203,7 +205,9 @@ class FederationClient(FederationBase):
response response
""" """
sent_queries_counter.inc("client_one_time_keys") sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content) return self.transport_layer.claim_client_keys(
destination, content, timeout
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function

View File

@ -298,7 +298,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def query_client_keys(self, destination, query_content): def query_client_keys(self, destination, query_content, timeout):
"""Query the device keys for a list of user ids hosted on a remote """Query the device keys for a list of user ids hosted on a remote
server. server.
@ -327,12 +327,13 @@ class TransportLayerClient(object):
destination=destination, destination=destination,
path=path, path=path,
data=query_content, data=query_content,
timeout=timeout,
) )
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def claim_client_keys(self, destination, query_content): def claim_client_keys(self, destination, query_content, timeout):
"""Claim one-time keys for a list of devices hosted on a remote server. """Claim one-time keys for a list of devices hosted on a remote server.
Request: Request:
@ -363,6 +364,7 @@ class TransportLayerClient(object):
destination=destination, destination=destination,
path=path, path=path,
data=query_content, data=query_content,
timeout=timeout,
) )
defer.returnValue(content) defer.returnValue(content)

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import json import json
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import errors from synapse.api.errors import SynapseError, CodeMessageException
import synapse.types from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,7 +30,6 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
@ -40,7 +39,7 @@ class E2eKeysHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def query_devices(self, query_body): def query_devices(self, query_body, timeout):
""" Handle a device key query from a client """ Handle a device key query from a client
{ {
@ -63,27 +62,50 @@ class E2eKeysHandler(object):
# separate users by domain. # separate users by domain.
# make a map from domain to user_id to device_ids # make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict) local_query = {}
remote_queries = {}
for user_id, device_ids in device_keys_query.items(): for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id) if self.is_mine_id(user_id):
queries_by_domain[user.domain][user_id] = device_ids local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries # do the queries
# TODO: do these in parallel failures = {}
results = {} results = {}
for destination, destination_query in queries_by_domain.items(): if local_query:
if destination == self.server_name: local_result = yield self.query_local_devices(local_query)
res = yield self.query_local_devices(destination_query) for user_id, keys in local_result.items():
else: if user_id in local_query:
res = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}
)
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
results[user_id] = keys results[user_id] = keys
defer.returnValue((200, {"device_keys": results})) @defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
try:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
]))
defer.returnValue((200, {
"device_keys": results, "failures": failures,
}))
@defer.inlineCallbacks @defer.inlineCallbacks
def query_local_devices(self, query): def query_local_devices(self, query):
@ -104,7 +126,7 @@ class E2eKeysHandler(object):
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s", logger.warning("Request for keys for non-local user %s",
user_id) user_id)
raise errors.SynapseError(400, "Not a user here") raise SynapseError(400, "Not a user here")
if not device_ids: if not device_ids:
local_query.append((user_id, None)) local_query.append((user_id, None))

View File

@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None, def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False): long_retries=False, timeout=None):
""" Sends the specifed json data using PUT """ Sends the specifed json data using PUT
Args: Args:
@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
use as the request body. use as the request body.
long_retries (bool): A boolean that indicates whether we should long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=True): def post_json(self, destination, path, data={}, long_retries=True,
timeout=None):
""" Sends the specifed json data using POST """ Sends the specifed json data using POST
Args: Args:
@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=True, long_retries=True,
timeout=timeout,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:

View File

@ -19,11 +19,12 @@ import simplejson as json
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors from synapse.api.errors import SynapseError, CodeMessageException
import synapse.server from synapse.http.servlet import (
import synapse.types RestServlet, parse_json_object_from_request, parse_integer
from synapse.http.servlet import RestServlet, parse_json_object_from_request )
from synapse.types import UserID from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
device_id = requester.device_id device_id = requester.device_id
if device_id is None: if device_id is None:
raise synapse.api.errors.SynapseError( raise SynapseError(
400, 400,
"To upload keys, you must pass device_id when authenticating" "To upload keys, you must pass device_id when authenticating"
) )
@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body) result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
auth_user_id = requester.user.to_string() auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else [] device_ids = [device_id] if device_id else []
result = yield self.e2e_keys_handler.query_devices( result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}} {"device_keys": {user_id: device_ids}},
timeout,
) )
defer.returnValue(result) defer.returnValue(result)
@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request( result = yield self.handle_request(
{"one_time_keys": {user_id: {device_id: algorithm}}} {"one_time_keys": {user_id: {device_id: algorithm}}},
timeout,
) )
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.handle_request(body) result = yield self.handle_request(body, timeout)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_request(self, body): def handle_request(self, body, timeout):
local_query = [] local_query = []
remote_queries = {} remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items(): for user_id, device_keys in body.get("one_time_keys", {}).items():
user = UserID.from_string(user_id) if self.is_mine_id(user_id):
if self.is_mine(user):
for device_id, algorithm in device_keys.items(): for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm)) local_query.append((user_id, device_id, algorithm))
else: else:
remote_queries.setdefault(user.domain, {})[user_id] = ( domain = get_domain_from_id(user_id)
device_keys remote_queries.setdefault(domain, {})[user_id] = device_keys
)
results = yield self.store.claim_e2e_one_time_keys(local_query) results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {} json_result = {}
failures = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items(): for key_id, json_bytes in keys.items():
@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
key_id: json.loads(json_bytes) key_id: json.loads(json_bytes)
} }
for destination, device_keys in remote_queries.items(): @defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys( remote_result = yield self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys} destination,
{"one_time_keys": device_keys},
timeout=timeout
) )
for user_id, keys in remote_result["one_time_keys"].items(): for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys: if user_id in device_keys:
json_result[user_id] = keys json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
defer.returnValue((200, {"one_time_keys": json_result})) yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue((200, {
"one_time_keys": json_result,
"failures": failures
}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):