Merge pull request #1110 from matrix-org/markjh/e2e_timeout
Add a timeout parameter for end2end key queries.
This commit is contained in:
commit
76b09c29b0
|
@ -176,7 +176,7 @@ class FederationClient(FederationBase):
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def query_client_keys(self, destination, content):
|
def query_client_keys(self, destination, content, timeout):
|
||||||
"""Query device keys for a device hosted on a remote server.
|
"""Query device keys for a device hosted on a remote server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -188,10 +188,12 @@ class FederationClient(FederationBase):
|
||||||
response
|
response
|
||||||
"""
|
"""
|
||||||
sent_queries_counter.inc("client_device_keys")
|
sent_queries_counter.inc("client_device_keys")
|
||||||
return self.transport_layer.query_client_keys(destination, content)
|
return self.transport_layer.query_client_keys(
|
||||||
|
destination, content, timeout
|
||||||
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def claim_client_keys(self, destination, content):
|
def claim_client_keys(self, destination, content, timeout):
|
||||||
"""Claims one-time keys for a device hosted on a remote server.
|
"""Claims one-time keys for a device hosted on a remote server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -203,7 +205,9 @@ class FederationClient(FederationBase):
|
||||||
response
|
response
|
||||||
"""
|
"""
|
||||||
sent_queries_counter.inc("client_one_time_keys")
|
sent_queries_counter.inc("client_one_time_keys")
|
||||||
return self.transport_layer.claim_client_keys(destination, content)
|
return self.transport_layer.claim_client_keys(
|
||||||
|
destination, content, timeout
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
|
|
@ -298,7 +298,7 @@ class TransportLayerClient(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def query_client_keys(self, destination, query_content):
|
def query_client_keys(self, destination, query_content, timeout):
|
||||||
"""Query the device keys for a list of user ids hosted on a remote
|
"""Query the device keys for a list of user ids hosted on a remote
|
||||||
server.
|
server.
|
||||||
|
|
||||||
|
@ -327,12 +327,13 @@ class TransportLayerClient(object):
|
||||||
destination=destination,
|
destination=destination,
|
||||||
path=path,
|
path=path,
|
||||||
data=query_content,
|
data=query_content,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def claim_client_keys(self, destination, query_content):
|
def claim_client_keys(self, destination, query_content, timeout):
|
||||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||||
|
|
||||||
Request:
|
Request:
|
||||||
|
@ -363,6 +364,7 @@ class TransportLayerClient(object):
|
||||||
destination=destination,
|
destination=destination,
|
||||||
path=path,
|
path=path,
|
||||||
data=query_content,
|
data=query_content,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import collections
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api import errors
|
from synapse.api.errors import SynapseError, CodeMessageException
|
||||||
import synapse.types
|
from synapse.types import get_domain_from_id
|
||||||
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -30,7 +30,6 @@ class E2eKeysHandler(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.server_name = hs.hostname
|
|
||||||
|
|
||||||
# doesn't really work as part of the generic query API, because the
|
# doesn't really work as part of the generic query API, because the
|
||||||
# query request requires an object POST, but we abuse the
|
# query request requires an object POST, but we abuse the
|
||||||
|
@ -40,7 +39,7 @@ class E2eKeysHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_devices(self, query_body):
|
def query_devices(self, query_body, timeout):
|
||||||
""" Handle a device key query from a client
|
""" Handle a device key query from a client
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -63,27 +62,50 @@ class E2eKeysHandler(object):
|
||||||
|
|
||||||
# separate users by domain.
|
# separate users by domain.
|
||||||
# make a map from domain to user_id to device_ids
|
# make a map from domain to user_id to device_ids
|
||||||
queries_by_domain = collections.defaultdict(dict)
|
local_query = {}
|
||||||
|
remote_queries = {}
|
||||||
|
|
||||||
for user_id, device_ids in device_keys_query.items():
|
for user_id, device_ids in device_keys_query.items():
|
||||||
user = synapse.types.UserID.from_string(user_id)
|
if self.is_mine_id(user_id):
|
||||||
queries_by_domain[user.domain][user_id] = device_ids
|
local_query[user_id] = device_ids
|
||||||
|
else:
|
||||||
|
domain = get_domain_from_id(user_id)
|
||||||
|
remote_queries.setdefault(domain, {})[user_id] = device_ids
|
||||||
|
|
||||||
# do the queries
|
# do the queries
|
||||||
# TODO: do these in parallel
|
failures = {}
|
||||||
results = {}
|
results = {}
|
||||||
for destination, destination_query in queries_by_domain.items():
|
if local_query:
|
||||||
if destination == self.server_name:
|
local_result = yield self.query_local_devices(local_query)
|
||||||
res = yield self.query_local_devices(destination_query)
|
for user_id, keys in local_result.items():
|
||||||
else:
|
if user_id in local_query:
|
||||||
res = yield self.federation.query_client_keys(
|
|
||||||
destination, {"device_keys": destination_query}
|
|
||||||
)
|
|
||||||
res = res["device_keys"]
|
|
||||||
for user_id, keys in res.items():
|
|
||||||
if user_id in destination_query:
|
|
||||||
results[user_id] = keys
|
results[user_id] = keys
|
||||||
|
|
||||||
defer.returnValue((200, {"device_keys": results}))
|
@defer.inlineCallbacks
|
||||||
|
def do_remote_query(destination):
|
||||||
|
destination_query = remote_queries[destination]
|
||||||
|
try:
|
||||||
|
remote_result = yield self.federation.query_client_keys(
|
||||||
|
destination,
|
||||||
|
{"device_keys": destination_query},
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
for user_id, keys in remote_result["device_keys"].items():
|
||||||
|
if user_id in destination_query:
|
||||||
|
results[user_id] = keys
|
||||||
|
except CodeMessageException as e:
|
||||||
|
failures[destination] = {
|
||||||
|
"status": e.code, "message": e.message
|
||||||
|
}
|
||||||
|
|
||||||
|
yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
|
preserve_fn(do_remote_query)(destination)
|
||||||
|
for destination in remote_queries
|
||||||
|
]))
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"device_keys": results, "failures": failures,
|
||||||
|
}))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_local_devices(self, query):
|
def query_local_devices(self, query):
|
||||||
|
@ -104,7 +126,7 @@ class E2eKeysHandler(object):
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
logger.warning("Request for keys for non-local user %s",
|
logger.warning("Request for keys for non-local user %s",
|
||||||
user_id)
|
user_id)
|
||||||
raise errors.SynapseError(400, "Not a user here")
|
raise SynapseError(400, "Not a user here")
|
||||||
|
|
||||||
if not device_ids:
|
if not device_ids:
|
||||||
local_query.append((user_id, None))
|
local_query.append((user_id, None))
|
||||||
|
|
|
@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def put_json(self, destination, path, data={}, json_data_callback=None,
|
def put_json(self, destination, path, data={}, json_data_callback=None,
|
||||||
long_retries=False):
|
long_retries=False, timeout=None):
|
||||||
""" Sends the specifed json data using PUT
|
""" Sends the specifed json data using PUT
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
|
||||||
use as the request body.
|
use as the request body.
|
||||||
long_retries (bool): A boolean that indicates whether we should
|
long_retries (bool): A boolean that indicates whether we should
|
||||||
retry for a short or long time.
|
retry for a short or long time.
|
||||||
|
timeout(int): How long to try (in ms) the destination for before
|
||||||
|
giving up. None indicates no timeout.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||||
|
@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
headers_dict={"Content-Type": ["application/json"]},
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
long_retries=long_retries,
|
long_retries=long_retries,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
|
@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def post_json(self, destination, path, data={}, long_retries=True):
|
def post_json(self, destination, path, data={}, long_retries=True,
|
||||||
|
timeout=None):
|
||||||
""" Sends the specifed json data using POST
|
""" Sends the specifed json data using POST
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
|
||||||
the request body. This will be encoded as JSON.
|
the request body. This will be encoded as JSON.
|
||||||
long_retries (bool): A boolean that indicates whether we should
|
long_retries (bool): A boolean that indicates whether we should
|
||||||
retry for a short or long time.
|
retry for a short or long time.
|
||||||
|
timeout(int): How long to try (in ms) the destination for before
|
||||||
|
giving up. None indicates no timeout.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||||
|
@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
|
||||||
body_callback=body_callback,
|
body_callback=body_callback,
|
||||||
headers_dict={"Content-Type": ["application/json"]},
|
headers_dict={"Content-Type": ["application/json"]},
|
||||||
long_retries=True,
|
long_retries=True,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
|
|
|
@ -19,11 +19,12 @@ import simplejson as json
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.api.errors
|
from synapse.api.errors import SynapseError, CodeMessageException
|
||||||
import synapse.server
|
from synapse.http.servlet import (
|
||||||
import synapse.types
|
RestServlet, parse_json_object_from_request, parse_integer
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import get_domain_from_id
|
||||||
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
|
||||||
device_id = requester.device_id
|
device_id = requester.device_id
|
||||||
|
|
||||||
if device_id is None:
|
if device_id is None:
|
||||||
raise synapse.api.errors.SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"To upload keys, you must pass device_id when authenticating"
|
"To upload keys, you must pass device_id when authenticating"
|
||||||
)
|
)
|
||||||
|
@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id):
|
def on_POST(self, request, user_id, device_id):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request)
|
||||||
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.e2e_keys_handler.query_devices(body)
|
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id):
|
def on_GET(self, request, user_id, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
auth_user_id = requester.user.to_string()
|
auth_user_id = requester.user.to_string()
|
||||||
user_id = user_id if user_id else auth_user_id
|
user_id = user_id if user_id else auth_user_id
|
||||||
device_ids = [device_id] if device_id else []
|
device_ids = [device_id] if device_id else []
|
||||||
result = yield self.e2e_keys_handler.query_devices(
|
result = yield self.e2e_keys_handler.query_devices(
|
||||||
{"device_keys": {user_id: device_ids}}
|
{"device_keys": {user_id: device_ids}},
|
||||||
|
timeout,
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id, algorithm):
|
def on_GET(self, request, user_id, device_id, algorithm):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request)
|
||||||
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
result = yield self.handle_request(
|
result = yield self.handle_request(
|
||||||
{"one_time_keys": {user_id: {device_id: algorithm}}}
|
{"one_time_keys": {user_id: {device_id: algorithm}}},
|
||||||
|
timeout,
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id, algorithm):
|
def on_POST(self, request, user_id, device_id, algorithm):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request)
|
||||||
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.handle_request(body)
|
result = yield self.handle_request(body, timeout)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_request(self, body):
|
def handle_request(self, body, timeout):
|
||||||
local_query = []
|
local_query = []
|
||||||
remote_queries = {}
|
remote_queries = {}
|
||||||
|
|
||||||
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
||||||
user = UserID.from_string(user_id)
|
if self.is_mine_id(user_id):
|
||||||
if self.is_mine(user):
|
|
||||||
for device_id, algorithm in device_keys.items():
|
for device_id, algorithm in device_keys.items():
|
||||||
local_query.append((user_id, device_id, algorithm))
|
local_query.append((user_id, device_id, algorithm))
|
||||||
else:
|
else:
|
||||||
remote_queries.setdefault(user.domain, {})[user_id] = (
|
domain = get_domain_from_id(user_id)
|
||||||
device_keys
|
remote_queries.setdefault(domain, {})[user_id] = device_keys
|
||||||
)
|
|
||||||
results = yield self.store.claim_e2e_one_time_keys(local_query)
|
results = yield self.store.claim_e2e_one_time_keys(local_query)
|
||||||
|
|
||||||
json_result = {}
|
json_result = {}
|
||||||
|
failures = {}
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
for device_id, keys in device_keys.items():
|
for device_id, keys in device_keys.items():
|
||||||
for key_id, json_bytes in keys.items():
|
for key_id, json_bytes in keys.items():
|
||||||
|
@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
key_id: json.loads(json_bytes)
|
key_id: json.loads(json_bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
for destination, device_keys in remote_queries.items():
|
@defer.inlineCallbacks
|
||||||
|
def claim_client_keys(destination):
|
||||||
|
device_keys = remote_queries[destination]
|
||||||
|
try:
|
||||||
remote_result = yield self.federation.claim_client_keys(
|
remote_result = yield self.federation.claim_client_keys(
|
||||||
destination, {"one_time_keys": device_keys}
|
destination,
|
||||||
|
{"one_time_keys": device_keys},
|
||||||
|
timeout=timeout
|
||||||
)
|
)
|
||||||
for user_id, keys in remote_result["one_time_keys"].items():
|
for user_id, keys in remote_result["one_time_keys"].items():
|
||||||
if user_id in device_keys:
|
if user_id in device_keys:
|
||||||
json_result[user_id] = keys
|
json_result[user_id] = keys
|
||||||
|
except CodeMessageException as e:
|
||||||
|
failures[destination] = {
|
||||||
|
"status": e.code, "message": e.message
|
||||||
|
}
|
||||||
|
|
||||||
defer.returnValue((200, {"one_time_keys": json_result}))
|
yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
|
preserve_fn(claim_client_keys)(destination)
|
||||||
|
for destination in remote_queries
|
||||||
|
]))
|
||||||
|
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"one_time_keys": json_result,
|
||||||
|
"failures": failures
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
|
Loading…
Reference in New Issue