Merge branch 'release-v0.18.4' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-11-22 10:25:23 +00:00
commit f9834a3d1a
32 changed files with 383 additions and 430 deletions

View File

@ -1,3 +1,31 @@
Changes in synapse v0.18.4 (2016-11-22)
=======================================
Bug fixes:
* Add workaround for buggy clients that the fail to register (PR #1632)
Changes in synapse v0.18.4-rc1 (2016-11-14)
===========================================
Changes:
* Various database efficiency improvements (PR #1188, #1192)
* Update default config to blacklist more internal IPs, thanks to Euan Kemp (PR
#1198)
* Allow specifying duration in minutes in config, thanks to Daniel Dent (PR
#1625)
Bug fixes:
* Fix media repo to set CORs headers on responses (PR #1190)
* Fix registration to not error on non-ascii passwords (PR #1191)
* Fix create event code to limit the number of prev_events (PR #1615)
* Fix bug in transaction ID deduplication (PR #1624)
Changes in synapse v0.18.3 (2016-11-08) Changes in synapse v0.18.3 (2016-11-08)
======================================= =======================================

View File

@ -51,9 +51,9 @@ python_gc_counts reactor_gc_counts
The twisted-specific reactor metrics have been renamed. The twisted-specific reactor metrics have been renamed.
==================================== ================= ==================================== =====================
New name Old name New name Old name
------------------------------------ ----------------- ------------------------------------ ---------------------
python_twisted_reactor_pending_calls reactor_tick_time python_twisted_reactor_pending_calls reactor_pending_calls
python_twisted_reactor_tick_time reactor_tick_time python_twisted_reactor_tick_time reactor_tick_time
==================================== ================= ==================================== =====================

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.18.3" __version__ = "0.18.4"

View File

@ -34,6 +34,8 @@ from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -151,6 +153,8 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
if config.notify_appservices: if config.notify_appservices:

View File

@ -41,6 +41,8 @@ from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -165,6 +167,8 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)

View File

@ -39,6 +39,8 @@ from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -156,6 +158,8 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)

View File

@ -41,6 +41,8 @@ from synapse.api.urls import (
) )
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -162,6 +164,8 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)

View File

@ -36,6 +36,8 @@ from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse import events
from twisted.internet import reactor, defer from twisted.internet import reactor, defer
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -239,6 +241,8 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if config.start_pushers: if config.start_pushers:
sys.stderr.write( sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process" "\nThe pushers must be disabled in the main synapse process"

View File

@ -446,6 +446,8 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
ss = SynchrotronServer( ss = SynchrotronServer(

View File

@ -64,11 +64,12 @@ class Config(object):
if isinstance(value, int) or isinstance(value, long): if isinstance(value, int) or isinstance(value, long):
return value return value
second = 1000 second = 1000
hour = 60 * 60 * second minute = 60 * second
hour = 60 * minute
day = 24 * hour day = 24 * hour
week = 7 * day week = 7 * day
year = 365 * day year = 365 * day
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year} sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year}
size = 1 size = 1
suffix = value[-1] suffix = value[-1]
if suffix in sizes: if suffix in sizes:

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
import importlib import importlib
@ -39,7 +39,12 @@ class PasswordAuthProviderConfig(Config):
module = importlib.import_module(module) module = importlib.import_module(module)
provider_class = getattr(module, clz) provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"]) provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
self.password_providers.append((provider_class, provider_config)) self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs): def default_config(self, **kwargs):

View File

@ -167,6 +167,8 @@ class ContentRepositoryConfig(Config):
# - '10.0.0.0/8' # - '10.0.0.0/8'
# - '172.16.0.0/12' # - '172.16.0.0/12'
# - '192.168.0.0/16' # - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
# #
# List of IP address CIDR ranges that the URL preview spider is allowed # List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist. # to access even if they are specified in url_preview_ip_range_blacklist.

View File

@ -653,7 +653,7 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Hashed password (str). Hashed password (str).
""" """
return bcrypt.hashpw(password + self.hs.config.password_pepper, return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds)) bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):

View File

@ -34,6 +34,7 @@ from ._base import BaseHandler
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -415,6 +416,20 @@ class MessageHandler(BaseHandler):
builder.room_id, builder.room_id,
) )
# We want to limit the max number of prev events we point to in our
# new event
if len(latest_ret) > 10:
# Sort by reverse depth, so we point to the most recent.
latest_ret.sort(key=lambda a: -a[2])
new_latest_ret = latest_ret[:5]
# We also randomly point to some of the older events, to make
# sure that we don't completely ignore the older events.
if latest_ret[5:]:
sample_size = min(5, len(latest_ret[5:]))
new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
latest_ret = new_latest_ret
if latest_ret: if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1 depth = max([d for _, _, d in latest_ret]) + 1
else: else:

View File

@ -392,17 +392,30 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
if send_cors: if send_cors:
request.setHeader("Access-Control-Allow-Origin", "*") set_cors_headers(request)
request.setHeader("Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS")
request.setHeader("Access-Control-Allow-Headers",
"Origin, X-Requested-With, Content-Type, Accept")
request.write(json_bytes) request.write(json_bytes)
finish_request(request) finish_request(request)
return NOT_DONE_YET return NOT_DONE_YET
def set_cors_headers(request):
"""Set the CORs headers so that javascript running in a web browsers can
use this API
Args:
request (twisted.web.http.Request): The http request to add CORs to.
"""
request.setHeader("Access-Control-Allow-Origin", "*")
request.setHeader(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
request.setHeader(
"Access-Control-Allow-Headers",
"Origin, X-Requested-With, Content-Type, Accept"
)
def finish_request(request): def finish_request(request):
""" Finish writing the response to the request. """ Finish writing the response to the request.

View File

@ -111,18 +111,20 @@ def render_all():
return "\n".join(strs) return "\n".join(strs)
reactor_metrics = get_metrics_for("reactor") register_process_collector(get_metrics_for("process"))
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
gc_time = reactor_metrics.register_distribution("gc_time", labels=["gen"])
gc_unreachable = reactor_metrics.register_counter("gc_unreachable", labels=["gen"])
reactor_metrics.register_callback( python_metrics = get_metrics_for("python")
gc_time = python_metrics.register_distribution("gc_time", labels=["gen"])
gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"])
python_metrics.register_callback(
"gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"] "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
) )
register_process_collector(get_metrics_for("process")) reactor_metrics = get_metrics_for("python.twisted.reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func): def runUntilCurrentTimer(func):

View File

@ -13,12 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Because otherwise 'resource' collides with synapse.metrics.resource
from __future__ import absolute_import
import os import os
import stat
from resource import getrusage, RUSAGE_SELF
TICKS_PER_SEC = 100 TICKS_PER_SEC = 100
@ -29,16 +24,6 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits") HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd") HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
TYPES = {
stat.S_IFSOCK: "SOCK",
stat.S_IFLNK: "LNK",
stat.S_IFREG: "REG",
stat.S_IFBLK: "BLK",
stat.S_IFDIR: "DIR",
stat.S_IFCHR: "CHR",
stat.S_IFIFO: "FIFO",
}
# Field indexes from /proc/self/stat, taken from the proc(5) manpage # Field indexes from /proc/self/stat, taken from the proc(5) manpage
STAT_FIELDS = { STAT_FIELDS = {
"utime": 14, "utime": 14,
@ -49,9 +34,7 @@ STAT_FIELDS = {
} }
rusage = None
stats = {} stats = {}
fd_counts = None
# In order to report process_start_time_seconds we need to know the # In order to report process_start_time_seconds we need to know the
# machine's boot time, because the value in /proc/self/stat is relative to # machine's boot time, because the value in /proc/self/stat is relative to
@ -65,9 +48,6 @@ if HAVE_PROC_STAT:
def update_resource_metrics(): def update_resource_metrics():
global rusage
rusage = getrusage(RUSAGE_SELF)
if HAVE_PROC_SELF_STAT: if HAVE_PROC_SELF_STAT:
global stats global stats
with open("/proc/self/stat") as s: with open("/proc/self/stat") as s:
@ -80,52 +60,17 @@ def update_resource_metrics():
# we've lost the first two fields in PID and COMMAND above # we've lost the first two fields in PID and COMMAND above
stats[name] = int(raw_stats[index - 3]) stats[name] = int(raw_stats[index - 3])
global fd_counts
fd_counts = _process_fds()
def _process_fds():
counts = {(k,): 0 for k in TYPES.values()}
counts[("other",)] = 0
def _count_fds():
# Not every OS will have a /proc/self/fd directory # Not every OS will have a /proc/self/fd directory
if not HAVE_PROC_SELF_FD: if not HAVE_PROC_SELF_FD:
return counts return 0
for fd in os.listdir("/proc/self/fd"): return len(os.listdir("/proc/self/fd"))
try:
s = os.stat("/proc/self/fd/%s" % (fd))
fmt = stat.S_IFMT(s.st_mode)
if fmt in TYPES:
t = TYPES[fmt]
else:
t = "other"
counts[(t,)] += 1
except OSError:
# the dirh itself used by listdir() is usually missing by now
pass
return counts
def register_process_collector(process_metrics): def register_process_collector(process_metrics):
# Legacy synapse-invented metric names process_metrics.register_collector(update_resource_metrics)
resource_metrics = process_metrics.make_subspace("resource")
resource_metrics.register_collector(update_resource_metrics)
# msecs
resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# kilobytes
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
process_metrics.register_callback("fds", _process_fds, labels=["type"])
# New prometheus-standard metric names
if HAVE_PROC_SELF_STAT: if HAVE_PROC_SELF_STAT:
process_metrics.register_callback( process_metrics.register_callback(
@ -158,7 +103,7 @@ def register_process_collector(process_metrics):
if HAVE_PROC_SELF_FD: if HAVE_PROC_SELF_FD:
process_metrics.register_callback( process_metrics.register_callback(
"open_fds", "open_fds",
lambda: sum(fd_counts.values()) lambda: _count_fds()
) )
if HAVE_PROC_SELF_LIMITS: if HAVE_PROC_SELF_LIMITS:

View File

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
logger = logging.getLogger(__name__)
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
def __init__(self, clock):
self.clock = clock
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn (function): A function which returns a tuple of
(response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
try:
return self.transactions[txn_key][0].observe()
except (KeyError, IndexError):
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()
def _cleanup(self):
now = self.clock.time_msec()
for key in self.transactions.keys():
ts = self.transactions[key][1]
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
del self.transactions[key]

View File

@ -18,7 +18,8 @@
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionStore from synapse.rest.client.transactions import HttpTransactionCache
import re import re
import logging import logging
@ -59,4 +60,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionCache(hs.get_clock())

View File

@ -53,19 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
client_path_patterns("/createRoom(?:/.*)?$"), client_path_patterns("/createRoom(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
@defer.inlineCallbacks
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -214,19 +205,10 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_type, txn_id): def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented") return (200, "Not implemented")
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(self, request, room_id, event_type, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, event_type, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_type, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
@ -283,19 +265,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
defer.returnValue((200, {"room_id": room_id})) defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_identifier, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing
@ -537,21 +510,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id): def on_PUT(self, request, room_id, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(
request, room_id, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing
@ -623,21 +585,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
return False return False
return True return True
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, membership_action, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(
request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
@ -669,19 +620,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(self, request, room_id, event_id, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, event_id, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_id, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet):

View File

@ -1,97 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__)
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
class HttpTransactionStore(object):
def __init__(self):
# { key : (txn_id, response) }
self.transactions = {}
def get_response(self, key, txn_id):
"""Retrieve a response for this request.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request
Returns:
A tuple of (HTTP response code, response content) or None.
"""
try:
logger.debug("get_response TxnId: %s", txn_id)
(last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", txn_id)
return response
except KeyError:
pass
return None
def store_response(self, key, txn_id, response):
"""Stores an HTTP response tuple.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content)
"""
logger.debug("store_response TxnId: %s", txn_id)
self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response):
"""Stores the request/response pair of an HTTP transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
response (tuple): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
"""
self.store_response(self._get_key(request), txn_id, response)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
txn_id (str): The transaction ID for this request.
Returns:
The response tuple.
Raises:
KeyError if the transaction was not found.
"""
response = self.get_response(self._get_key(request), txn_id)
if response is None:
raise KeyError("Transaction not found.")
return response
def _get_key(self, request):
token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token

View File

@ -169,6 +169,17 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None) guest_access_token = body.get("guest_access_token", None)
if (
'initial_device_display_name' in body and
'password' not in body
):
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
# the original registration params
logger.warn("Ignoring initial_device_display_name without password")
del body['initial_device_display_name']
session_id = self.auth_handler.get_session_id(body) session_id = self.auth_handler.get_session_id(body)
registered_user_id = None registered_user_id = None
if session_id: if session_id:

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionStore from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -40,18 +40,16 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionCache(hs.get_clock())
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self._put, request, message_type, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
@defer.inlineCallbacks
def _put(self, request, message_type, txn_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -63,7 +61,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
) )
response = (200, {}) response = (200, {})
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)

View File

@ -15,7 +15,7 @@
from ._base import parse_media_id, respond_with_file, respond_404 from ._base import parse_media_id, respond_with_file, respond_404
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.http.server import request_handler from synapse.http.server import request_handler, set_cors_headers
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -45,6 +45,7 @@ class DownloadResource(Resource):
@request_handler() @request_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader( request.setHeader(
"Content-Security-Policy", "Content-Security-Policy",
"default-src 'none';" "default-src 'none';"

View File

@ -17,7 +17,7 @@
from ._base import parse_media_id, respond_404, respond_with_file from ._base import parse_media_id, respond_404, respond_with_file
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler from synapse.http.server import request_handler, set_cors_headers
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -48,6 +48,7 @@ class ThumbnailResource(Resource):
@request_handler() @request_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width") width = parse_integer(request, "width")
height = parse_integer(request, "height") height = parse_integer(request, "height")

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 37 SCHEMA_VERSION = 38
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -137,17 +137,8 @@ class PusherStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=1, max_entries=15000) @cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id): def get_if_user_has_pusher(self, user_id):
result = yield self._simple_select_many_batch( # This only exists for the cachedList decorator
table='pushers', raise NotImplementedError()
keyvalues={
'user_name': 'user_id',
},
retcol='user_name',
desc='get_if_user_has_pusher',
allow_none=True,
)
defer.returnValue(bool(result))
@cachedList(cached_method_name="get_if_user_has_pusher", @cachedList(cached_method_name="get_if_user_has_pusher",
list_name="user_ids", num_args=1, inlineCallbacks=True) list_name="user_ids", num_args=1, inlineCallbacks=True)

View File

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('event_search_postgres_gist', '{}');

View File

@ -31,6 +31,7 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
def __init__(self, hs): def __init__(self, hs):
super(SearchStore, self).__init__(hs) super(SearchStore, self).__init__(hs)
@ -41,6 +42,10 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_reindex_search_order self._background_reindex_search_order
) )
self.register_background_update_handler(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
self._background_reindex_gist_search
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size): def _background_reindex_search(self, progress, batch_size):
@ -139,6 +144,28 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _background_reindex_gist_search(self, progress, batch_size):
def create_index(conn):
conn.rollback()
conn.set_session(autocommit=True)
c = conn.cursor()
c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
" ON event_search USING GIST (vector)"
)
c.execute("DROP INDEX event_search_fts_idx")
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
yield self.runWithConnection(create_index)
yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
defer.returnValue(1)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size): def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]

View File

@ -16,13 +16,12 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from twisted.internet import defer, reactor from twisted.internet import defer
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import namedtuple from collections import namedtuple
import itertools
import logging import logging
import ujson as json import ujson as json
@ -50,20 +49,6 @@ class TransactionStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(TransactionStore, self).__init__(hs) super(TransactionStore, self).__init__(hs)
# New transactions that are currently in flights
self.inflight_transactions = {}
# Newly delievered transactions that *weren't* persisted while in flight
self.new_delivered_transactions = {}
# Newly delivered transactions that *were* persisted while in flight
self.update_delivered_transactions = {}
self.last_transaction = {}
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
self._clock.looping_call(self._persist_in_mem_txns, 1000)
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
def get_received_txn_response(self, transaction_id, origin): def get_received_txn_response(self, transaction_id, origin):
@ -148,46 +133,7 @@ class TransactionStore(SQLBaseStore):
Returns: Returns:
list: A list of previous transaction ids. list: A list of previous transaction ids.
""" """
return defer.succeed([])
auto_id = self._transaction_id_gen.get_next()
txn_row = _TransactionRow(
id=auto_id,
transaction_id=transaction_id,
destination=destination,
ts=origin_server_ts,
response_code=0,
response_json=None,
)
self.inflight_transactions.setdefault(destination, {})[transaction_id] = txn_row
prev_txn = self.last_transaction.get(destination)
if prev_txn:
return defer.succeed(prev_txn)
else:
return self.runInteraction(
"_get_prevs_txn",
self._get_prevs_txn,
destination,
)
def _get_prevs_txn(self, txn, destination):
# First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
# we can simply take the last one.
query = (
"SELECT * FROM sent_transactions"
" WHERE destination = ?"
" ORDER BY id DESC LIMIT 1"
)
txn.execute(query, (destination,))
results = self.cursor_to_dict(txn)
prev_txns = [r["transaction_id"] for r in results]
return prev_txns
def delivered_txn(self, transaction_id, destination, code, response_dict): def delivered_txn(self, transaction_id, destination, code, response_dict):
"""Persists the response for an outgoing transaction. """Persists the response for an outgoing transaction.
@ -198,52 +144,7 @@ class TransactionStore(SQLBaseStore):
code (int) code (int)
response_json (str) response_json (str)
""" """
pass
txn_row = self.inflight_transactions.get(
destination, {}
).pop(transaction_id, None)
self.last_transaction[destination] = transaction_id
if txn_row:
d = self.new_delivered_transactions.setdefault(destination, {})
d[transaction_id] = txn_row._replace(
response_code=code,
response_json=None, # For now, don't persist response
)
else:
d = self.update_delivered_transactions.setdefault(destination, {})
# For now, don't persist response
d[transaction_id] = _UpdateTransactionRow(code, None)
def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id.
Args:
transaction_id (str)
destination (str)
Returns:
list: A list of dicts
"""
return self.runInteraction(
"get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
def _get_transactions_after(self, txn, transaction_id, destination):
query = (
"SELECT * FROM sent_transactions"
" WHERE destination = ? AND id >"
" ("
" SELECT id FROM sent_transactions"
" WHERE transaction_id = ? AND destination = ?"
" )"
)
txn.execute(query, (destination, transaction_id, destination))
return self.cursor_to_dict(txn)
@cached(max_entries=10000) @cached(max_entries=10000)
def get_destination_retry_timings(self, destination): def get_destination_retry_timings(self, destination):
@ -339,58 +240,11 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (self._clock.time_msec(),)) txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
@defer.inlineCallbacks
def _persist_in_mem_txns(self):
try:
inflight = self.inflight_transactions
new_delivered = self.new_delivered_transactions
update_delivered = self.update_delivered_transactions
self.inflight_transactions = {}
self.new_delivered_transactions = {}
self.update_delivered_transactions = {}
full_rows = [
row._asdict()
for txn_map in itertools.chain(inflight.values(), new_delivered.values())
for row in txn_map.values()
]
def f(txn):
if full_rows:
self._simple_insert_many_txn(
txn=txn,
table="sent_transactions",
values=full_rows
)
for dest, txn_map in update_delivered.items():
for txn_id, update_row in txn_map.items():
self._simple_update_one_txn(
txn,
table="sent_transactions",
keyvalues={
"transaction_id": txn_id,
"destination": dest,
},
updatevalues={
"response_code": update_row.response_code,
"response_json": None, # For now, don't persist response
}
)
if full_rows or update_delivered:
yield self.runInteraction("_persist_in_mem_txns", f)
except:
logger.exception("Failed to persist transactions!")
def _cleanup_transactions(self): def _cleanup_transactions(self):
now = self._clock.time_msec() now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000 month_ago = now - 30 * 24 * 60 * 60 * 1000
six_hours_ago = now - 6 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn): def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn) return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)

View File

@ -34,7 +34,7 @@ class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.
TODO(paul): Also move the sleep() functionallity into it TODO(paul): Also move the sleep() functionality into it
""" """
def time(self): def time(self):
@ -46,6 +46,14 @@ class Clock(object):
return int(self.time() * 1000) return int(self.time() * 1000)
def looping_call(self, f, msec): def looping_call(self, f, msec):
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
"""
l = task.LoopingCall(f) l = task.LoopingCall(f)
l.start(msec / 1000.0, now=False) l.start(msec / 1000.0, now=False)
return l return l

View File

@ -0,0 +1,69 @@
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
from twisted.internet import defer
from mock import Mock, call
from tests import unittest
from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks
def test_executes_given_function(self):
cb = Mock(
return_value=defer.succeed(self.mock_http_response)
)
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response)
@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
cb = Mock(
return_value=defer.succeed(self.mock_http_response)
)
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
)
self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work
cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
@defer.inlineCallbacks
def test_cleans_up(self):
cb = Mock(
return_value=defer.succeed(self.mock_http_response)
)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# no longer using cache
self.assertEqual(cb.call_count, 2)
self.assertEqual(
cb.call_args_list,
[call("an arg",), call("an arg",)]
)