Merge remote-tracking branch 'origin/develop' into dbkr/make_notif_highlight_query_fast

This commit is contained in:
David Baker 2016-09-09 19:11:34 +01:00
commit b91e2833b3
32 changed files with 969 additions and 258 deletions

View File

@ -1,3 +1,47 @@
Changes in synapse v0.17.3 (2016-09-09)
=======================================
This release fixes a major bug that stopped servers from handling rooms with
over 1000 members.
Changes in synapse v0.17.2 (2016-09-08)
=======================================
This release contains security bug fixes. Please upgrade.
No changes since v0.17.2-rc1
Changes in synapse v0.17.2-rc1 (2016-09-05)
===========================================
Features:
* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
#1062, #1066)
Changes:
* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
#1068)
* Don't notify for online to online presence transitions. (PR #1054)
* Occasionally persist unpersisted presence updates (PR #1055)
* Allow application services to have an optional 'url' (PR #1056)
* Clean up old sent transactions from DB (PR #1059)
Bug fixes:
* Fix None check in backfill (PR #1043)
* Fix membership changes to be idempotent (PR #1067)
* Fix bug in get_pdu where it would sometimes return events with incorrect
signature
Changes in synapse v0.17.1 (2016-08-24) Changes in synapse v0.17.1 (2016-08-24)
======================================= =======================================

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.1" __version__ = "0.17.3"

View File

@ -583,12 +583,15 @@ class Auth(object):
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
user_id = yield self._get_appservice_user_id(request.args) user_id = yield self._get_appservice_user_id(request)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id)) defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0] access_token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
user_info = yield self.get_user_by_access_token(access_token, rights) user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -629,17 +632,19 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request_args): def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token( app_service = yield self.store.get_app_service_by_token(
request_args["access_token"][0] get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
) )
if app_service is None: if app_service is None:
defer.returnValue(None) defer.returnValue(None)
if "user_id" not in request_args: if "user_id" not in request.args:
defer.returnValue(app_service.sender) defer.returnValue(app_service.sender)
user_id = request_args["user_id"][0] user_id = request.args["user_id"][0]
if app_service.sender == user_id: if app_service.sender == user_id:
defer.returnValue(app_service.sender) defer.returnValue(app_service.sender)
@ -833,7 +838,9 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: try:
token = request.args["access_token"][0] token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,)) logger.warn("Unrecognised appservice access token: %s" % (token,))
@ -1142,3 +1149,40 @@ class Auth(object):
"This server requires you to be a moderator in the room to" "This server requires you to be a moderator in the room to"
" edit its room list entry" " edit its room list entry"
) )
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
return bool(query_params)
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
query_params = request.args.get("access_token")
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]

View File

@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable" APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_metadata(info):
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
return False
return True
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.quote(protocol) urllib.quote(protocol)
) )
try: try:
defer.returnValue((yield self.get_json(uri, {}))) info = yield self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning("query_3pe_protocol to %s did not return a"
" valid result", uri)
defer.returnValue(None)
defer.returnValue(info)
except Exception as ex: except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex) uri, ex)
defer.returnValue({}) defer.returnValue(None)
key = (service.id, protocol) key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or ( return self.protocol_meta_cache.get(key) or (

View File

@ -137,6 +137,12 @@ class FederationClient(FederationBase):
self._transaction_queue.enqueue_edu(edu) self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None) return defer.succeed(None)
@log_function
def send_device_messages(self, destination):
"""Sends the device messages in the local database to the remote
destination"""
self._transaction_queue.enqueue_device_messages(destination)
@log_function @log_function
def send_failure(self, failure, destination): def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination) self._transaction_queue.enqueue_failure(failure, destination)

View File

@ -188,7 +188,7 @@ class FederationServer(FederationBase):
except SynapseError as e: except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e) logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e: except Exception as e:
logger.exception("Failed to handle edu %r", edu_type, e) logger.exception("Failed to handle edu %r", edu_type)
else: else:
logger.warn("Received EDU of type %s with no handler", edu_type) logger.warn("Received EDU of type %s with no handler", edu_type)

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -81,6 +81,8 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred) # destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
@ -155,10 +157,19 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
if not self.can_send_to(destination):
return
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
yield run_on_reactor()
while True:
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending
@ -171,39 +182,15 @@ class TransactionQueue(object):
) )
return return
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return
yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures
)
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination) yield run_on_reactor()
txn_id = str(self._next_txn_id) while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
@ -211,13 +198,85 @@ class TransactionQueue(object):
self.store, self.store,
) )
device_message_edus, device_stream_id = (
yield self._get_new_device_messages(destination)
)
pending_edus.extend(device_message_edus)
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter,
)
if not success:
break
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
@defer.inlineCallbacks
def _get_new_device_messages(self, destination):
last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
to_device_stream_id = self.store.get_to_device_stream_token()
contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
destination, last_device_stream_id, to_device_stream_id
)
edus = [
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.direct_to_device",
content=content,
)
for content in contents
]
defer.returnValue((edus, stream_id))
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id,
should_delete_from_device_stream, limiter):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
success = True
try:
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
logger.debug( logger.debug(
"TX [%s] {%s} Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d, failures: %d)",
destination, txn_id, destination, txn_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures) len(failures)
) )
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
@ -242,9 +301,9 @@ class TransactionQueue(object):
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures), len(failures),
) )
with limiter: with limiter:
@ -299,12 +358,14 @@ class TransactionQueue(object):
logger.info( logger.info(
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
except NotRetryingDestination: success = False
logger.info( else:
"TX [%s] not ready for retry yet - " # Remove the acknowledged device messages from the database
"dropping transaction for now", if should_delete_from_device_stream:
destination, yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
) )
self.last_device_stream_id_by_dest[destination] = device_stream_id
except RuntimeError as e: except RuntimeError as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
@ -314,6 +375,8 @@ class TransactionQueue(object):
e, e,
) )
success = False
for p in pdus: for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination) logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e: except Exception as e:
@ -325,9 +388,9 @@ class TransactionQueue(object):
e, e,
) )
success = False
for p in pdus: for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination) logger.info("Failed to send event %s to %s", p.event_id, destination)
finally: defer.returnValue(success)
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)

View File

@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_3pe_protocols(self): def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
protocols = {} protocols = {}
# Collect up all the individual protocol responses out of the ASes
for s in services: for s in services:
for p in s.protocols: for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) if only_protocol is not None and p != only_protocol:
continue
if p not in protocols:
protocols[p] = []
info = yield self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
def _merge_instances(infos):
if not infos:
return {}
# Merge the 'instances' lists of multiple results, but just take
# the other fields from the first as they ought to be identical
# copy the result so as not to corrupt the cached one
combined = dict(infos[0])
combined["instances"] = list(combined["instances"])
for info in infos[1:]:
combined["instances"].extend(info["instances"])
return combined
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols) defer.returnValue(protocols)

View File

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
@defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
message_id = random_string(16)
remote_edu_contents = {}
for destination, messages in remote_messages.items():
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
}
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation.send_device_messages(destination)

View File

@ -265,6 +265,12 @@ class PresenceHandler(object):
to_notify = {} # Changes we want to notify everyone about to_notify = {} # Changes we want to notify everyone about
to_federation_ping = {} # These need sending keep-alives to_federation_ping = {} # These need sending keep-alives
# Only bother handling the last presence change for each user
new_states_dict = {}
for new_state in new_states:
new_states_dict[new_state.user_id] = new_state
new_state = new_states_dict.values()
for new_state in new_states: for new_state in new_states:
user_id = new_state.user_id user_id = new_state.user_id
@ -651,6 +657,13 @@ class PresenceHandler(object):
) )
continue continue
if get_domain_from_id(user_id) != origin:
logger.info(
"Got presence update from %r with bad 'user_id': %r",
origin, user_id,
)
continue
presence_state = push.get("presence", None) presence_state = push.get("presence", None)
if not presence_state: if not presence_state:
logger.info( logger.info(

View File

@ -443,6 +443,16 @@ class RoomListHandler(BaseHandler):
self.remote_list_request_cache.set((), deferred) self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name):
res = yield self.hs.get_replication_layer().get_public_rooms(
[server_name]
)
if server_name not in res:
raise SynapseError(404, "Server not found")
defer.returnValue(res[server_name])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_aggregated_public_room_list(self): def get_aggregated_public_room_list(self):
""" """

View File

@ -199,7 +199,14 @@ class TypingHandler(object):
user_id = content["user_id"] user_id = content["user_id"]
# Check that the string is a valid user id # Check that the string is a valid user id
UserID.from_string(user_id) user = UserID.from_string(user_id)
if user.domain != origin:
logger.info(
"Got typing update from %r with bad 'user_id': %r",
origin, user_id,
)
return
users = yield self.state.get_current_user_in_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users) domains = set(get_domain_from_id(u) for u in users)

View File

@ -181,7 +181,7 @@ class ReplicationResource(Resource):
def replicate(self, request_streams, limit): def replicate(self, request_streams, limit):
writer = _Writer() writer = _Writer()
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token) logger.debug("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit, request_streams) yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams)
@ -195,7 +195,7 @@ class ReplicationResource(Resource):
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.debug("Replicated %d rows", writer.total)
defer.returnValue(writer.finish()) defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams): def streams(self, writer, current_token, request_streams):

View File

@ -16,13 +16,18 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(BaseSlavedStore): class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker( self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id", db_conn, "device_max_stream_id", "stream_id",
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token()
) )
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
@ -38,5 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
stream = result.get("to_device") stream = result.get("to_device")
if stream: if stream:
self._device_inbox_id_gen.advance(int(stream["position"])) self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
return super(SlavedDeviceInboxStore, self).process_replication(result) return super(SlavedDeviceInboxStore, self).process_replication(result)

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
try: access_token = get_access_token_from_request(request)
access_token = request.args["access_token"][0]
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
yield self.store.delete_access_token(access_token) yield self.store.delete_access_token(access_token)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
if "access_token" not in request.args: as_token = get_access_token_from_request(request)
raise SynapseError(400, "Expected application service token.")
if "user" not in register_json: if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.") raise SynapseError(400, "Expected 'user' key.")
as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
if "access_token" not in request.args: access_token = get_access_token_from_request(request)
raise SynapseError(400, "Expected application service token.")
app_service = yield self.store.get_app_service_by_token( app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0] access_token
) )
if not app_service: if not app_service:
raise SynapseError(403, "Invalid application service token.") raise SynapseError(403, "Invalid application service token.")

View File

@ -22,8 +22,8 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request, parse_string
import logging import logging
import urllib import urllib
@ -120,6 +120,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(request, "format", default="content",
allowed_values=["content", "event"])
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -134,6 +136,11 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
raise SynapseError( raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND 404, "Event not found.", errcode=Codes.NOT_FOUND
) )
if format == "event":
event = format_event_for_client_v2(data.get_dict())
defer.returnValue((200, event))
elif format == "content":
defer.returnValue((200, data.get_dict()["content"])) defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -295,14 +302,25 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
server = parse_string(request, "server", default=None)
try: try:
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
except AuthError: except AuthError as e:
# This endpoint isn't authed, but its useful to know who's hitting # We allow people to not be authed if they're just looking at our
# it if they *do* supply an access token # room list, but require auth when we proxy the request.
# In both cases we call the auth function, as that has the side
# effect of logging who issued this request if an access token was
# provided.
if server:
raise e
else:
pass pass
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(server)
else:
data = yield handler.get_aggregated_public_room_list() data = yield handler.get_aggregated_public_room_list()
defer.returnValue((200, data)) defer.returnValue((200, data))

View File

@ -17,6 +17,8 @@
to ensure idempotency when performing PUTs using the REST API.""" to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -90,6 +92,6 @@ class HttpTransactionStore(object):
return response return response
def _get_key(self, request): def _get_key(self, request):
token = request.args["access_token"][0] token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0] path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token return path_without_txn_id + "/" + token

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username'] desired_username = body['username']
appservice = None appservice = None
if 'access_token' in request.args: if has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which # fork off as soon as possible for ASes and shared secret auth which
@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll # 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one. # fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username) desired_username = body.get("user", desired_username)
access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring): if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration( result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0], body desired_username, access_token, body
) )
defer.returnValue((200, result)) # we throw for non 200 responses defer.returnValue((200, result)) # we throw for non 200 responses
return return

View File

@ -16,10 +16,11 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionStore from synapse.rest.client.v1.transactions import HttpTransactionStore
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()
self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
# TODO: Prod the notifier to wake up sync streams. sender_user_id = requester.user.to_string()
# TODO: Implement replication for the messages.
# TODO: Send the messages to remote servers if needed.
local_messages = {} yield self.device_message_handler.send_device_message(
for user_id, by_device in content["messages"].items(): sender_user_id, message_type, content["messages"]
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": requester.user.to_string(),
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
) )
response = (200, {}) response = (200, {})

View File

@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):
defer.returnValue((200, protocols)) defer.returnValue((200, protocols))
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
releases=())
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols(
only_protocol=protocol,
)
if protocol in protocols:
defer.returnValue((200, protocols[protocol]))
else:
defer.returnValue((404, {"error": "Unknown protocol"}))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
@ -57,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
fields = request.args fields = request.args
del fields["access_token"] fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe( results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields ThirdPartyEntityKind.USER, protocol, fields
@ -81,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
fields = request.args fields = request.args
del fields["access_token"] fields.pop("access_token", None)
results = yield self.appservice_handler.query_3pe( results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields ThirdPartyEntityKind.LOCATION, protocol, fields
@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)

View File

@ -35,6 +35,7 @@ from synapse.federation import initialize_http_replication
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
@ -100,6 +101,7 @@ class HomeServer(object):
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
'device_message_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'client_resource', 'client_resource',
@ -205,6 +207,9 @@ class HomeServer(object):
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)
def build_device_message_handler(self):
return DeviceMessageHandler(self)
def build_e2e_keys_handler(self): def build_e2e_keys_handler(self):
return E2eKeysHandler(self) return E2eKeysHandler(self)

View File

@ -111,7 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream", "stream_id" db_conn, "presence_stream", "stream_id"
) )
self._device_inbox_id_gen = StreamIdGenerator( self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_inbox", "stream_id" db_conn, "device_max_stream_id", "stream_id"
) )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
@ -182,6 +182,30 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=push_rules_prefill, prefilled_cache=push_rules_prefill,
) )
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id,
prefilled_cache=device_inbox_prefill,
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
db_conn, "device_federation_outbox",
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View File

@ -133,9 +133,11 @@ class BackgroundUpdateStore(SQLBaseStore):
updates = yield self._simple_select_list( updates = yield self._simple_select_list(
"background_updates", "background_updates",
keyvalues=None, keyvalues=None,
retcols=("update_name",), retcols=("update_name", "depends_on"),
) )
in_flight = set(update["update_name"] for update in updates)
for update in updates: for update in updates:
if update["depends_on"] not in in_flight:
self._background_update_queue.append(update['update_name']) self._background_update_queue.append(update['update_name'])
if not self._background_update_queue: if not self._background_update_queue:

View File

@ -27,37 +27,157 @@ logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore): class DeviceInboxStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_messages_to_device_inbox(self, messages_by_user_then_device): def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
""" remote_messages_by_destination):
"""Used to send messages from this server.
Args: Args:
messages_by_user_and_device(dict): sender_user_id(str): The ID of the user sending these messages.
local_messages_by_user_and_device(dict):
Dictionary of user_id to device_id to message. Dictionary of user_id to device_id to message.
remote_messages_by_destination(dict):
Dictionary of destination server_name to the EDU JSON to send.
Returns: Returns:
A deferred stream_id that resolves when the messages have been A deferred stream_id that resolves when the messages have been
inserted. inserted.
""" """
def select_devices_txn(txn, user_id, devices): def add_messages_txn(txn, now_ms, stream_id):
if not devices: # Add the local messages directly to the local inbox.
return [] self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
)
# Add the remote messages to the federation outbox.
# We'll send them to a remote server when we next send a
# federation transaction to that destination.
sql = ( sql = (
"SELECT user_id, device_id FROM devices" "INSERT INTO device_federation_outbox"
" (destination, stream_id, queued_ts, messages_json)"
" VALUES (?,?,?,?)"
)
rows = []
for destination, edu in remote_messages_by_destination.items():
edu_json = ujson.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
"add_messages_to_device_inbox",
add_messages_txn,
now_ms,
stream_id,
)
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
for destination in remote_messages_by_destination.keys():
self._device_federation_outbox_stream_cache.entity_has_changed(
destination, stream_id
)
defer.returnValue(self._device_inbox_id_gen.get_current_token())
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
self, origin, message_id, local_messages_by_user_then_device
):
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
already_inserted = self._simple_select_one_txn(
txn, table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
retcols=("message_id",),
allow_none=True,
)
if already_inserted is not None:
return
# Add an entry for this message_id so that we know we've processed
# it.
self._simple_insert_txn(
txn, table="device_federation_inbox",
values={
"origin": origin,
"message_id": message_id,
"received_ts": now_ms,
},
)
# Add the messages to the approriate local device inboxes so that
# they'll be sent to the devices when they next sync.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
stream_id,
)
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
messages_by_user_then_device):
sql = (
"UPDATE device_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
devices = messages_by_device.keys()
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
"SELECT device_id FROM devices"
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall():
# Add the message for all devices for this user on this
# server.
device = row[0]
messages_json_for_user[device] = message_json
else:
if not devices:
continue
sql = (
"SELECT device_id FROM devices"
" WHERE user_id = ? AND device_id IN (" " WHERE user_id = ? AND device_id IN ("
+ ",".join("?" * len(devices)) + ",".join("?" * len(devices))
+ ")" + ")"
) )
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
args = [user_id] + devices txn.execute(sql, [user_id] + devices)
txn.execute(sql, args) for row in txn.fetchall():
return [tuple(row) for row in txn.fetchall()] # Only insert into the local inbox if the device exists on
# this server
device = row[0]
message_json = ujson.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
def add_messages_to_device_inbox_txn(txn, stream_id): if messages_json_for_user:
local_users_and_devices = set() local_by_user_then_device[user_id] = messages_json_for_user
for user_id, messages_by_device in messages_by_user_then_device.items():
local_users_and_devices.update( if not local_by_user_then_device:
select_devices_txn(txn, user_id, messages_by_device.keys()) return
)
sql = ( sql = (
"INSERT INTO device_inbox" "INSERT INTO device_inbox"
@ -65,25 +185,12 @@ class DeviceInboxStore(SQLBaseStore):
" VALUES (?,?,?,?)" " VALUES (?,?,?,?)"
) )
rows = [] rows = []
for user_id, messages_by_device in messages_by_user_then_device.items(): for user_id, messages_by_device in local_by_user_then_device.items():
for device_id, message in messages_by_device.items(): for device_id, message_json in messages_by_device.items():
message_json = ujson.dumps(message)
# Only insert into the local inbox if the device exists on
# this server
if (user_id, device_id) in local_users_and_devices:
rows.append((user_id, device_id, stream_id, message_json)) rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows) txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_messages_to_device_inbox",
add_messages_to_device_inbox_txn,
stream_id
)
defer.returnValue(self._device_inbox_id_gen.get_current_token())
def get_new_messages_for_device( def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100 self, user_id, device_id, last_stream_id, current_stream_id, limit=100
): ):
@ -97,6 +204,12 @@ class DeviceInboxStore(SQLBaseStore):
Deferred ([dict], int): List of messages for the device and where Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to. in the stream the messages got to.
""" """
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn): def get_new_messages_for_device_txn(txn):
sql = ( sql = (
"SELECT stream_id, message_json FROM device_inbox" "SELECT stream_id, message_json FROM device_inbox"
@ -182,3 +295,71 @@ class DeviceInboxStore(SQLBaseStore):
def get_to_device_stream_token(self): def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int): The last position of the device message stream
that the server sent up to.
current_stream_id(int): The current position of the device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn.fetchall():
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)

View File

@ -343,7 +343,7 @@ class EventPushActionsStore(SQLBaseStore):
def f(txn): def f(txn):
before_clause = "" before_clause = ""
if before: if before:
before_clause = "AND stream_ordering < ?" before_clause = "AND epa.stream_ordering < ?"
args = [user_id, before, limit] args = [user_id, before, limit]
else: else:
args = [user_id, limit] args = [user_id, limit]

View File

@ -402,7 +402,7 @@ class RoomMemberStore(SQLBaseStore):
keyvalues={ keyvalues={
"membership": Membership.JOIN, "membership": Membership.JOIN,
}, },
batch_size=1000, batch_size=500,
desc="_get_joined_users_from_context", desc="_get_joined_users_from_context",
) )

View File

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');

View File

@ -0,0 +1,39 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DROP TABLE IF EXISTS device_federation_outbox;
CREATE TABLE device_federation_outbox (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
queued_ts BIGINT NOT NULL,
messages_json TEXT NOT NULL
);
DROP INDEX IF EXISTS device_federation_outbox_destination_id;
CREATE INDEX device_federation_outbox_destination_id
ON device_federation_outbox(destination, stream_id);
DROP TABLE IF EXISTS device_federation_inbox;
CREATE TABLE device_federation_inbox (
origin TEXT NOT NULL,
message_id TEXT NOT NULL,
received_ts BIGINT NOT NULL
);
DROP INDEX IF EXISTS device_federation_inbox_sender_id;
CREATE INDEX device_federation_inbox_sender_id
ON device_federation_inbox(origin, message_id);

View File

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE device_max_stream_id (
stream_id BIGINT NOT NULL
);
INSERT INTO device_max_stream_id (stream_id)
SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox;

View File

@ -48,6 +48,7 @@ class StateStore(SQLBaseStore):
""" """
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
def __init__(self, hs): def __init__(self, hs):
super(StateStore, self).__init__(hs) super(StateStore, self).__init__(hs)
@ -55,6 +56,10 @@ class StateStore(SQLBaseStore):
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state, self._background_deduplicate_state,
) )
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
@ -793,3 +798,31 @@ class StateStore(SQLBaseStore):
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME) yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
@defer.inlineCallbacks
def _background_index_state(self, progress, batch_size):
def reindex_txn(txn):
if isinstance(self.database_engine, PostgresEngine):
txn.execute(
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
else:
txn.execute(
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
yield self.runInteraction(
self.STATE_GROUP_INDEX_UPDATE_NAME, reindex_txn
)
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
defer.returnValue(1)

View File

@ -121,6 +121,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
lambda *args, **kargs: ([], 0)
)
self.datastore.delete_device_msgs_for_remote = (
lambda *args, **kargs: None
)
# Some local users to test with # Some local users to test with
self.u_apple = UserID.from_string("@apple:test") self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test") self.u_banana = UserID.from_string("@banana:test")