E2E keys: Make federation query share code with client query
Refactor the e2e query handler to separate out the local query, and then make the federation handler use it.
This commit is contained in:
parent
986615b0b2
commit
1efee2f52b
|
@ -348,27 +348,9 @@ class FederationServer(FederationBase):
|
||||||
(200, send_content)
|
(200, send_content)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
@log_function
|
@log_function
|
||||||
def on_query_client_keys(self, origin, content):
|
def on_query_client_keys(self, origin, content):
|
||||||
query = []
|
return self.on_query_request("client_keys", content)
|
||||||
for user_id, device_ids in content.get("device_keys", {}).items():
|
|
||||||
if not device_ids:
|
|
||||||
query.append((user_id, None))
|
|
||||||
else:
|
|
||||||
for device_id in device_ids:
|
|
||||||
query.append((user_id, device_id))
|
|
||||||
|
|
||||||
results = yield self.store.get_e2e_device_keys(query)
|
|
||||||
|
|
||||||
json_result = {}
|
|
||||||
for user_id, device_keys in results.items():
|
|
||||||
for device_id, json_bytes in device_keys.items():
|
|
||||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
|
||||||
json_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue({"device_keys": json_result})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
|
|
@ -367,10 +367,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
|
||||||
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||||
PATH = "/user/keys/query"
|
PATH = "/user/keys/query"
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_POST(self, origin, content, query):
|
def on_POST(self, origin, content, query):
|
||||||
response = yield self.handler.on_query_client_keys(origin, content)
|
return self.handler.on_query_client_keys(origin, content)
|
||||||
defer.returnValue((200, response))
|
|
||||||
|
|
||||||
|
|
||||||
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||||
|
|
|
@ -13,12 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api import errors
|
||||||
import synapse.types
|
import synapse.types
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -29,39 +32,101 @@ class E2eKeysHandler(BaseHandler):
|
||||||
super(E2eKeysHandler, self).__init__(hs)
|
super(E2eKeysHandler, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
self.is_mine = hs.is_mine
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
|
# doesn't really work as part of the generic query API, because the
|
||||||
|
# query request requires an object POST, but we abuse the
|
||||||
|
# "query handler" interface.
|
||||||
|
self.federation.register_query_handler(
|
||||||
|
"client_keys", self.on_federation_query_client_keys
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_devices(self, query_body):
|
def query_devices(self, query_body):
|
||||||
local_query = []
|
""" Handle a device key query from a client
|
||||||
remote_queries = {}
|
|
||||||
for user_id, device_ids in query_body.get("device_keys", {}).items():
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": ["<device_id>"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
->
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
device_keys_query = query_body.get("device_keys", {})
|
||||||
|
|
||||||
|
# separate users by domain.
|
||||||
|
# make a map from domain to user_id to device_ids
|
||||||
|
queries_by_domain = collections.defaultdict(dict)
|
||||||
|
for user_id, device_ids in device_keys_query.items():
|
||||||
user = synapse.types.UserID.from_string(user_id)
|
user = synapse.types.UserID.from_string(user_id)
|
||||||
if self.is_mine(user):
|
queries_by_domain[user.domain][user_id] = device_ids
|
||||||
|
|
||||||
|
# do the queries
|
||||||
|
# TODO: do these in parallel
|
||||||
|
results = {}
|
||||||
|
for destination, destination_query in queries_by_domain.items():
|
||||||
|
if destination == self.hs.hostname:
|
||||||
|
res = yield self.query_local_devices(destination_query)
|
||||||
|
else:
|
||||||
|
res = yield self.federation.query_client_keys(
|
||||||
|
destination, {"device_keys": destination_query}
|
||||||
|
)
|
||||||
|
res = res["device_keys"]
|
||||||
|
for user_id, keys in res.items():
|
||||||
|
if user_id in destination_query:
|
||||||
|
results[user_id] = keys
|
||||||
|
|
||||||
|
defer.returnValue((200, {"device_keys": results}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def query_local_devices(self, query):
|
||||||
|
"""Get E2E device keys for local users
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (dict[string, list[string]|None): map from user_id to a list
|
||||||
|
of devices to query (None for all devices)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
|
||||||
|
map from user_id -> device_id -> device details
|
||||||
|
"""
|
||||||
|
local_query = []
|
||||||
|
|
||||||
|
for user_id, device_ids in query.items():
|
||||||
|
if not self.is_mine_id(user_id):
|
||||||
|
logger.warning("Request for keys for non-local user %s",
|
||||||
|
user_id)
|
||||||
|
raise errors.SynapseError(400, "Not a user here")
|
||||||
|
|
||||||
if not device_ids:
|
if not device_ids:
|
||||||
local_query.append((user_id, None))
|
local_query.append((user_id, None))
|
||||||
else:
|
else:
|
||||||
for device_id in device_ids:
|
for device_id in device_ids:
|
||||||
local_query.append((user_id, device_id))
|
local_query.append((user_id, device_id))
|
||||||
else:
|
|
||||||
remote_queries.setdefault(user.domain, {})[user_id] = list(
|
|
||||||
device_ids
|
|
||||||
)
|
|
||||||
results = yield self.store.get_e2e_device_keys(local_query)
|
results = yield self.store.get_e2e_device_keys(local_query)
|
||||||
|
|
||||||
json_result = {}
|
# un-jsonify the results
|
||||||
|
json_result = collections.defaultdict(dict)
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
for device_id, json_bytes in device_keys.items():
|
for device_id, json_bytes in device_keys.items():
|
||||||
json_result.setdefault(user_id, {})[
|
json_result[user_id][device_id] = json.loads(json_bytes)
|
||||||
device_id] = json.loads(
|
|
||||||
json_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
for destination, device_keys in remote_queries.items():
|
defer.returnValue(json_result)
|
||||||
remote_result = yield self.federation.query_client_keys(
|
|
||||||
destination, {"device_keys": device_keys}
|
@defer.inlineCallbacks
|
||||||
)
|
def on_federation_query_client_keys(self, query_body):
|
||||||
for user_id, keys in remote_result["device_keys"].items():
|
""" Handle a device key query from a federated server
|
||||||
if user_id in device_keys:
|
"""
|
||||||
json_result[user_id] = keys
|
device_keys_query = query_body.get("device_keys", {})
|
||||||
defer.returnValue((200, {"device_keys": json_result}))
|
res = yield self.query_local_devices(device_keys_query)
|
||||||
|
defer.returnValue({"device_keys": res})
|
||||||
|
|
Loading…
Reference in New Issue