Merge branch 'release-v0.18.4' of github.com:matrix-org/synapse
This commit is contained in:
commit
f9834a3d1a
28
CHANGES.rst
28
CHANGES.rst
|
@ -1,3 +1,31 @@
|
|||
Changes in synapse v0.18.4 (2016-11-22)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Add workaround for buggy clients that the fail to register (PR #1632)
|
||||
|
||||
|
||||
Changes in synapse v0.18.4-rc1 (2016-11-14)
|
||||
===========================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Various database efficiency improvements (PR #1188, #1192)
|
||||
* Update default config to blacklist more internal IPs, thanks to Euan Kemp (PR
|
||||
#1198)
|
||||
* Allow specifying duration in minutes in config, thanks to Daniel Dent (PR
|
||||
#1625)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix media repo to set CORs headers on responses (PR #1190)
|
||||
* Fix registration to not error on non-ascii passwords (PR #1191)
|
||||
* Fix create event code to limit the number of prev_events (PR #1615)
|
||||
* Fix bug in transaction ID deduplication (PR #1624)
|
||||
|
||||
|
||||
Changes in synapse v0.18.3 (2016-11-08)
|
||||
=======================================
|
||||
|
||||
|
|
|
@ -51,9 +51,9 @@ python_gc_counts reactor_gc_counts
|
|||
|
||||
The twisted-specific reactor metrics have been renamed.
|
||||
|
||||
==================================== =================
|
||||
==================================== =====================
|
||||
New name Old name
|
||||
------------------------------------ -----------------
|
||||
python_twisted_reactor_pending_calls reactor_tick_time
|
||||
------------------------------------ ---------------------
|
||||
python_twisted_reactor_pending_calls reactor_pending_calls
|
||||
python_twisted_reactor_tick_time reactor_tick_time
|
||||
==================================== =================
|
||||
==================================== =====================
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.18.3"
|
||||
__version__ = "0.18.4"
|
||||
|
|
|
@ -34,6 +34,8 @@ from synapse.util.manhole import manhole
|
|||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from synapse import events
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
|
@ -151,6 +153,8 @@ def start(config_options):
|
|||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
if config.notify_appservices:
|
||||
|
|
|
@ -41,6 +41,8 @@ from synapse.util.rlimit import change_resource_limit
|
|||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
from synapse import events
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
@ -165,6 +167,8 @@ def start(config_options):
|
|||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
|
|
@ -39,6 +39,8 @@ from synapse.api.urls import FEDERATION_PREFIX
|
|||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
from synapse import events
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
@ -156,6 +158,8 @@ def start(config_options):
|
|||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
|
|
@ -41,6 +41,8 @@ from synapse.api.urls import (
|
|||
)
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
from synapse import events
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
@ -162,6 +164,8 @@ def start(config_options):
|
|||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
|
|
@ -36,6 +36,8 @@ from synapse.util.manhole import manhole
|
|||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from synapse import events
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
|
@ -239,6 +241,8 @@ def start(config_options):
|
|||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
if config.start_pushers:
|
||||
sys.stderr.write(
|
||||
"\nThe pushers must be disabled in the main synapse process"
|
||||
|
|
|
@ -446,6 +446,8 @@ def start(config_options):
|
|||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
ss = SynchrotronServer(
|
||||
|
|
|
@ -64,11 +64,12 @@ class Config(object):
|
|||
if isinstance(value, int) or isinstance(value, long):
|
||||
return value
|
||||
second = 1000
|
||||
hour = 60 * 60 * second
|
||||
minute = 60 * second
|
||||
hour = 60 * minute
|
||||
day = 24 * hour
|
||||
week = 7 * day
|
||||
year = 365 * day
|
||||
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
|
||||
sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year}
|
||||
size = 1
|
||||
suffix = value[-1]
|
||||
if suffix in sizes:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
import importlib
|
||||
|
||||
|
@ -39,7 +39,12 @@ class PasswordAuthProviderConfig(Config):
|
|||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
self.password_providers.append((provider_class, provider_config))
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
|
|
|
@ -167,6 +167,8 @@ class ContentRepositoryConfig(Config):
|
|||
# - '10.0.0.0/8'
|
||||
# - '172.16.0.0/12'
|
||||
# - '192.168.0.0/16'
|
||||
# - '100.64.0.0/10'
|
||||
# - '169.254.0.0/16'
|
||||
#
|
||||
# List of IP address CIDR ranges that the URL preview spider is allowed
|
||||
# to access even if they are specified in url_preview_ip_range_blacklist.
|
||||
|
|
|
@ -653,7 +653,7 @@ class AuthHandler(BaseHandler):
|
|||
Returns:
|
||||
Hashed password (str).
|
||||
"""
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
|
|
|
@ -34,6 +34,7 @@ from ._base import BaseHandler
|
|||
from canonicaljson import encode_canonical_json
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -415,6 +416,20 @@ class MessageHandler(BaseHandler):
|
|||
builder.room_id,
|
||||
)
|
||||
|
||||
# We want to limit the max number of prev events we point to in our
|
||||
# new event
|
||||
if len(latest_ret) > 10:
|
||||
# Sort by reverse depth, so we point to the most recent.
|
||||
latest_ret.sort(key=lambda a: -a[2])
|
||||
new_latest_ret = latest_ret[:5]
|
||||
|
||||
# We also randomly point to some of the older events, to make
|
||||
# sure that we don't completely ignore the older events.
|
||||
if latest_ret[5:]:
|
||||
sample_size = min(5, len(latest_ret[5:]))
|
||||
new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
|
||||
latest_ret = new_latest_ret
|
||||
|
||||
if latest_ret:
|
||||
depth = max([d for _, _, d in latest_ret]) + 1
|
||||
else:
|
||||
|
|
|
@ -392,17 +392,30 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
|
|||
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
|
||||
|
||||
if send_cors:
|
||||
request.setHeader("Access-Control-Allow-Origin", "*")
|
||||
request.setHeader("Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, OPTIONS")
|
||||
request.setHeader("Access-Control-Allow-Headers",
|
||||
"Origin, X-Requested-With, Content-Type, Accept")
|
||||
set_cors_headers(request)
|
||||
|
||||
request.write(json_bytes)
|
||||
finish_request(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
|
||||
def set_cors_headers(request):
|
||||
"""Set the CORs headers so that javascript running in a web browsers can
|
||||
use this API
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The http request to add CORs to.
|
||||
"""
|
||||
request.setHeader("Access-Control-Allow-Origin", "*")
|
||||
request.setHeader(
|
||||
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
request.setHeader(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, X-Requested-With, Content-Type, Accept"
|
||||
)
|
||||
|
||||
|
||||
def finish_request(request):
|
||||
""" Finish writing the response to the request.
|
||||
|
||||
|
|
|
@ -111,18 +111,20 @@ def render_all():
|
|||
return "\n".join(strs)
|
||||
|
||||
|
||||
reactor_metrics = get_metrics_for("reactor")
|
||||
tick_time = reactor_metrics.register_distribution("tick_time")
|
||||
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
|
||||
register_process_collector(get_metrics_for("process"))
|
||||
|
||||
gc_time = reactor_metrics.register_distribution("gc_time", labels=["gen"])
|
||||
gc_unreachable = reactor_metrics.register_counter("gc_unreachable", labels=["gen"])
|
||||
|
||||
reactor_metrics.register_callback(
|
||||
python_metrics = get_metrics_for("python")
|
||||
|
||||
gc_time = python_metrics.register_distribution("gc_time", labels=["gen"])
|
||||
gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"])
|
||||
python_metrics.register_callback(
|
||||
"gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
|
||||
)
|
||||
|
||||
register_process_collector(get_metrics_for("process"))
|
||||
reactor_metrics = get_metrics_for("python.twisted.reactor")
|
||||
tick_time = reactor_metrics.register_distribution("tick_time")
|
||||
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
|
||||
|
||||
|
||||
def runUntilCurrentTimer(func):
|
||||
|
|
|
@ -13,12 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Because otherwise 'resource' collides with synapse.metrics.resource
|
||||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
import stat
|
||||
from resource import getrusage, RUSAGE_SELF
|
||||
|
||||
|
||||
TICKS_PER_SEC = 100
|
||||
|
@ -29,16 +24,6 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
|
|||
HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
|
||||
HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
|
||||
|
||||
TYPES = {
|
||||
stat.S_IFSOCK: "SOCK",
|
||||
stat.S_IFLNK: "LNK",
|
||||
stat.S_IFREG: "REG",
|
||||
stat.S_IFBLK: "BLK",
|
||||
stat.S_IFDIR: "DIR",
|
||||
stat.S_IFCHR: "CHR",
|
||||
stat.S_IFIFO: "FIFO",
|
||||
}
|
||||
|
||||
# Field indexes from /proc/self/stat, taken from the proc(5) manpage
|
||||
STAT_FIELDS = {
|
||||
"utime": 14,
|
||||
|
@ -49,9 +34,7 @@ STAT_FIELDS = {
|
|||
}
|
||||
|
||||
|
||||
rusage = None
|
||||
stats = {}
|
||||
fd_counts = None
|
||||
|
||||
# In order to report process_start_time_seconds we need to know the
|
||||
# machine's boot time, because the value in /proc/self/stat is relative to
|
||||
|
@ -65,9 +48,6 @@ if HAVE_PROC_STAT:
|
|||
|
||||
|
||||
def update_resource_metrics():
|
||||
global rusage
|
||||
rusage = getrusage(RUSAGE_SELF)
|
||||
|
||||
if HAVE_PROC_SELF_STAT:
|
||||
global stats
|
||||
with open("/proc/self/stat") as s:
|
||||
|
@ -80,52 +60,17 @@ def update_resource_metrics():
|
|||
# we've lost the first two fields in PID and COMMAND above
|
||||
stats[name] = int(raw_stats[index - 3])
|
||||
|
||||
global fd_counts
|
||||
fd_counts = _process_fds()
|
||||
|
||||
|
||||
def _process_fds():
|
||||
counts = {(k,): 0 for k in TYPES.values()}
|
||||
counts[("other",)] = 0
|
||||
|
||||
def _count_fds():
|
||||
# Not every OS will have a /proc/self/fd directory
|
||||
if not HAVE_PROC_SELF_FD:
|
||||
return counts
|
||||
return 0
|
||||
|
||||
for fd in os.listdir("/proc/self/fd"):
|
||||
try:
|
||||
s = os.stat("/proc/self/fd/%s" % (fd))
|
||||
fmt = stat.S_IFMT(s.st_mode)
|
||||
if fmt in TYPES:
|
||||
t = TYPES[fmt]
|
||||
else:
|
||||
t = "other"
|
||||
|
||||
counts[(t,)] += 1
|
||||
except OSError:
|
||||
# the dirh itself used by listdir() is usually missing by now
|
||||
pass
|
||||
|
||||
return counts
|
||||
return len(os.listdir("/proc/self/fd"))
|
||||
|
||||
|
||||
def register_process_collector(process_metrics):
|
||||
# Legacy synapse-invented metric names
|
||||
|
||||
resource_metrics = process_metrics.make_subspace("resource")
|
||||
|
||||
resource_metrics.register_collector(update_resource_metrics)
|
||||
|
||||
# msecs
|
||||
resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
|
||||
resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
|
||||
|
||||
# kilobytes
|
||||
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
|
||||
|
||||
process_metrics.register_callback("fds", _process_fds, labels=["type"])
|
||||
|
||||
# New prometheus-standard metric names
|
||||
process_metrics.register_collector(update_resource_metrics)
|
||||
|
||||
if HAVE_PROC_SELF_STAT:
|
||||
process_metrics.register_callback(
|
||||
|
@ -158,7 +103,7 @@ def register_process_collector(process_metrics):
|
|||
if HAVE_PROC_SELF_FD:
|
||||
process_metrics.register_callback(
|
||||
"open_fds",
|
||||
lambda: sum(fd_counts.values())
|
||||
lambda: _count_fds()
|
||||
)
|
||||
|
||||
if HAVE_PROC_SELF_LIMITS:
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains logic for storing HTTP PUT transactions. This is used
|
||||
to ensure idempotency when performing PUTs using the REST API."""
|
||||
import logging
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
from synapse.util.async import ObservableDeferred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_transaction_key(request):
|
||||
"""A helper function which returns a transaction key that can be used
|
||||
with TransactionCache for idempotent requests.
|
||||
|
||||
Idempotency is based on the returned key being the same for separate
|
||||
requests to the same endpoint. The key is formed from the HTTP request
|
||||
path and the access_token for the requesting user.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The incoming request. Must
|
||||
contain an access_token.
|
||||
Returns:
|
||||
str: A transaction key
|
||||
"""
|
||||
token = get_access_token_from_request(request)
|
||||
return request.path + "/" + token
|
||||
|
||||
|
||||
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
||||
|
||||
|
||||
class HttpTransactionCache(object):
|
||||
|
||||
def __init__(self, clock):
|
||||
self.clock = clock
|
||||
self.transactions = {
|
||||
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
||||
}
|
||||
# Try to clean entries every 30 mins. This means entries will exist
|
||||
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
||||
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
||||
|
||||
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
||||
"""A helper function for fetch_or_execute which extracts
|
||||
a transaction key from the given request.
|
||||
|
||||
See:
|
||||
fetch_or_execute
|
||||
"""
|
||||
return self.fetch_or_execute(
|
||||
get_transaction_key(request), fn, *args, **kwargs
|
||||
)
|
||||
|
||||
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
|
||||
"""Fetches the response for this transaction, or executes the given function
|
||||
to produce a response for this transaction.
|
||||
|
||||
Args:
|
||||
txn_key (str): A key to ensure idempotency should fetch_or_execute be
|
||||
called again at a later point in time.
|
||||
fn (function): A function which returns a tuple of
|
||||
(response_code, response_dict).
|
||||
*args: Arguments to pass to fn.
|
||||
**kwargs: Keyword arguments to pass to fn.
|
||||
Returns:
|
||||
Deferred which resolves to a tuple of (response_code, response_dict).
|
||||
"""
|
||||
try:
|
||||
return self.transactions[txn_key][0].observe()
|
||||
except (KeyError, IndexError):
|
||||
pass # execute the function instead.
|
||||
|
||||
deferred = fn(*args, **kwargs)
|
||||
observable = ObservableDeferred(deferred)
|
||||
self.transactions[txn_key] = (observable, self.clock.time_msec())
|
||||
return observable.observe()
|
||||
|
||||
def _cleanup(self):
|
||||
now = self.clock.time_msec()
|
||||
for key in self.transactions.keys():
|
||||
ts = self.transactions[key][1]
|
||||
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
|
||||
del self.transactions[key]
|
|
@ -18,7 +18,8 @@
|
|||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.api.urls import CLIENT_PREFIX
|
||||
from .transactions import HttpTransactionStore
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
|
||||
import re
|
||||
|
||||
import logging
|
||||
|
@ -59,4 +60,4 @@ class ClientV1RestServlet(RestServlet):
|
|||
self.hs = hs
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
self.auth = hs.get_v1auth()
|
||||
self.txns = HttpTransactionStore()
|
||||
self.txns = HttpTransactionCache(hs.get_clock())
|
||||
|
|
|
@ -53,19 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||
client_path_patterns("/createRoom(?:/.*)?$"),
|
||||
self.on_OPTIONS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = yield self.on_POST(request)
|
||||
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -214,19 +205,10 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
def on_GET(self, request, room_id, event_type, txn_id):
|
||||
return (200, "Not implemented")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, event_type, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = yield self.on_POST(request, room_id, event_type, txn_id)
|
||||
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, event_type, txn_id
|
||||
)
|
||||
|
||||
|
||||
# TODO: Needs unit testing for room ID + alias joins
|
||||
|
@ -283,19 +265,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue((200, {"room_id": room_id}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_identifier, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = yield self.on_POST(request, room_identifier, txn_id)
|
||||
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_identifier, txn_id
|
||||
)
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
|
@ -537,22 +510,11 @@ class RoomForgetRestServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = yield self.on_POST(
|
||||
request, room_id, txn_id
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, txn_id
|
||||
)
|
||||
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||
|
@ -623,22 +585,11 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
return False
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, membership_action, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = yield self.on_POST(
|
||||
request, room_id, membership_action, txn_id
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, membership_action, txn_id
|
||||
)
|
||||
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
|
||||
|
||||
class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
|
@ -669,19 +620,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue((200, {"event_id": event.event_id}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, event_id, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = yield self.on_POST(request, room_id, event_id, txn_id)
|
||||
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, event_id, txn_id
|
||||
)
|
||||
|
||||
|
||||
class RoomTypingRestServlet(ClientV1RestServlet):
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains logic for storing HTTP PUT transactions. This is used
|
||||
to ensure idempotency when performing PUTs using the REST API."""
|
||||
import logging
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
|
||||
class HttpTransactionStore(object):
|
||||
|
||||
def __init__(self):
|
||||
# { key : (txn_id, response) }
|
||||
self.transactions = {}
|
||||
|
||||
def get_response(self, key, txn_id):
|
||||
"""Retrieve a response for this request.
|
||||
|
||||
Args:
|
||||
key (str): A transaction-independent key for this request. Usually
|
||||
this is a combination of the path (without the transaction id)
|
||||
and the user's access token.
|
||||
txn_id (str): The transaction ID for this request
|
||||
Returns:
|
||||
A tuple of (HTTP response code, response content) or None.
|
||||
"""
|
||||
try:
|
||||
logger.debug("get_response TxnId: %s", txn_id)
|
||||
(last_txn_id, response) = self.transactions[key]
|
||||
if txn_id == last_txn_id:
|
||||
logger.info("get_response: Returning a response for %s", txn_id)
|
||||
return response
|
||||
except KeyError:
|
||||
pass
|
||||
return None
|
||||
|
||||
def store_response(self, key, txn_id, response):
|
||||
"""Stores an HTTP response tuple.
|
||||
|
||||
Args:
|
||||
key (str): A transaction-independent key for this request. Usually
|
||||
this is a combination of the path (without the transaction id)
|
||||
and the user's access token.
|
||||
txn_id (str): The transaction ID for this request.
|
||||
response (tuple): A tuple of (HTTP response code, response content)
|
||||
"""
|
||||
logger.debug("store_response TxnId: %s", txn_id)
|
||||
self.transactions[key] = (txn_id, response)
|
||||
|
||||
def store_client_transaction(self, request, txn_id, response):
|
||||
"""Stores the request/response pair of an HTTP transaction.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The twisted HTTP request. This
|
||||
request must have the transaction ID as the last path segment.
|
||||
response (tuple): A tuple of (response code, response dict)
|
||||
txn_id (str): The transaction ID for this request.
|
||||
"""
|
||||
self.store_response(self._get_key(request), txn_id, response)
|
||||
|
||||
def get_client_transaction(self, request, txn_id):
|
||||
"""Retrieves a stored response if there was one.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The twisted HTTP request. This
|
||||
request must have the transaction ID as the last path segment.
|
||||
txn_id (str): The transaction ID for this request.
|
||||
Returns:
|
||||
The response tuple.
|
||||
Raises:
|
||||
KeyError if the transaction was not found.
|
||||
"""
|
||||
response = self.get_response(self._get_key(request), txn_id)
|
||||
if response is None:
|
||||
raise KeyError("Transaction not found.")
|
||||
return response
|
||||
|
||||
def _get_key(self, request):
|
||||
token = get_access_token_from_request(request)
|
||||
path_without_txn_id = request.path.rsplit("/", 1)[0]
|
||||
return path_without_txn_id + "/" + token
|
|
@ -169,6 +169,17 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
guest_access_token = body.get("guest_access_token", None)
|
||||
|
||||
if (
|
||||
'initial_device_display_name' in body and
|
||||
'password' not in body
|
||||
):
|
||||
# ignore 'initial_device_display_name' if sent without
|
||||
# a password to work around a client bug where it sent
|
||||
# the 'initial_device_display_name' param alone, wiping out
|
||||
# the original registration params
|
||||
logger.warn("Ignoring initial_device_display_name without password")
|
||||
del body['initial_device_display_name']
|
||||
|
||||
session_id = self.auth_handler.get_session_id(body)
|
||||
registered_user_id = None
|
||||
if session_id:
|
||||
|
|
|
@ -19,7 +19,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.http import servlet
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.rest.client.v1.transactions import HttpTransactionStore
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
|
@ -40,18 +40,16 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||
super(SendToDeviceRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.txns = HttpTransactionStore()
|
||||
self.txns = HttpTransactionCache(hs.get_clock())
|
||||
self.device_message_handler = hs.get_device_message_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, message_type, txn_id):
|
||||
try:
|
||||
defer.returnValue(
|
||||
self.txns.get_client_transaction(request, txn_id)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self._put, request, message_type, txn_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _put(self, request, message_type, txn_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -63,7 +61,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||
)
|
||||
|
||||
response = (200, {})
|
||||
self.txns.store_client_transaction(request, txn_id, response)
|
||||
defer.returnValue(response)
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
from ._base import parse_media_id, respond_with_file, respond_404
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.server import request_handler
|
||||
from synapse.http.server import request_handler, set_cors_headers
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
@ -45,6 +45,7 @@ class DownloadResource(Resource):
|
|||
@request_handler()
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
set_cors_headers(request)
|
||||
request.setHeader(
|
||||
"Content-Security-Policy",
|
||||
"default-src 'none';"
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from ._base import parse_media_id, respond_404, respond_with_file
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.servlet import parse_string, parse_integer
|
||||
from synapse.http.server import request_handler
|
||||
from synapse.http.server import request_handler, set_cors_headers
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
@ -48,6 +48,7 @@ class ThumbnailResource(Resource):
|
|||
@request_handler()
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
set_cors_headers(request)
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
width = parse_integer(request, "width")
|
||||
height = parse_integer(request, "height")
|
||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 37
|
||||
SCHEMA_VERSION = 38
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
|
|
@ -137,17 +137,8 @@ class PusherStore(SQLBaseStore):
|
|||
|
||||
@cachedInlineCallbacks(num_args=1, max_entries=15000)
|
||||
def get_if_user_has_pusher(self, user_id):
|
||||
result = yield self._simple_select_many_batch(
|
||||
table='pushers',
|
||||
keyvalues={
|
||||
'user_name': 'user_id',
|
||||
},
|
||||
retcol='user_name',
|
||||
desc='get_if_user_has_pusher',
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
defer.returnValue(bool(result))
|
||||
# This only exists for the cachedList decorator
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="get_if_user_has_pusher",
|
||||
list_name="user_ids", num_args=1, inlineCallbacks=True)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT into background_updates (update_name, progress_json)
|
||||
VALUES ('event_search_postgres_gist', '{}');
|
|
@ -31,6 +31,7 @@ class SearchStore(BackgroundUpdateStore):
|
|||
|
||||
EVENT_SEARCH_UPDATE_NAME = "event_search"
|
||||
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
|
||||
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(SearchStore, self).__init__(hs)
|
||||
|
@ -41,6 +42,10 @@ class SearchStore(BackgroundUpdateStore):
|
|||
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||
self._background_reindex_search_order
|
||||
)
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
|
||||
self._background_reindex_gist_search
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_search(self, progress, batch_size):
|
||||
|
@ -139,6 +144,28 @@ class SearchStore(BackgroundUpdateStore):
|
|||
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_gist_search(self, progress, batch_size):
|
||||
def create_index(conn):
|
||||
conn.rollback()
|
||||
conn.set_session(autocommit=True)
|
||||
c = conn.cursor()
|
||||
|
||||
c.execute(
|
||||
"CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
|
||||
" ON event_search USING GIST (vector)"
|
||||
)
|
||||
|
||||
c.execute("DROP INDEX event_search_fts_idx")
|
||||
|
||||
conn.set_session(autocommit=False)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
yield self.runWithConnection(create_index)
|
||||
|
||||
yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
|
||||
defer.returnValue(1)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_search_order(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
|
|
|
@ -16,13 +16,12 @@
|
|||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
|
@ -50,20 +49,6 @@ class TransactionStore(SQLBaseStore):
|
|||
def __init__(self, hs):
|
||||
super(TransactionStore, self).__init__(hs)
|
||||
|
||||
# New transactions that are currently in flights
|
||||
self.inflight_transactions = {}
|
||||
|
||||
# Newly delievered transactions that *weren't* persisted while in flight
|
||||
self.new_delivered_transactions = {}
|
||||
|
||||
# Newly delivered transactions that *were* persisted while in flight
|
||||
self.update_delivered_transactions = {}
|
||||
|
||||
self.last_transaction = {}
|
||||
|
||||
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
|
||||
self._clock.looping_call(self._persist_in_mem_txns, 1000)
|
||||
|
||||
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
|
||||
|
||||
def get_received_txn_response(self, transaction_id, origin):
|
||||
|
@ -148,46 +133,7 @@ class TransactionStore(SQLBaseStore):
|
|||
Returns:
|
||||
list: A list of previous transaction ids.
|
||||
"""
|
||||
|
||||
auto_id = self._transaction_id_gen.get_next()
|
||||
|
||||
txn_row = _TransactionRow(
|
||||
id=auto_id,
|
||||
transaction_id=transaction_id,
|
||||
destination=destination,
|
||||
ts=origin_server_ts,
|
||||
response_code=0,
|
||||
response_json=None,
|
||||
)
|
||||
|
||||
self.inflight_transactions.setdefault(destination, {})[transaction_id] = txn_row
|
||||
|
||||
prev_txn = self.last_transaction.get(destination)
|
||||
if prev_txn:
|
||||
return defer.succeed(prev_txn)
|
||||
else:
|
||||
return self.runInteraction(
|
||||
"_get_prevs_txn",
|
||||
self._get_prevs_txn,
|
||||
destination,
|
||||
)
|
||||
|
||||
def _get_prevs_txn(self, txn, destination):
|
||||
# First we find out what the prev_txns should be.
|
||||
# Since we know that we are only sending one transaction at a time,
|
||||
# we can simply take the last one.
|
||||
query = (
|
||||
"SELECT * FROM sent_transactions"
|
||||
" WHERE destination = ?"
|
||||
" ORDER BY id DESC LIMIT 1"
|
||||
)
|
||||
|
||||
txn.execute(query, (destination,))
|
||||
results = self.cursor_to_dict(txn)
|
||||
|
||||
prev_txns = [r["transaction_id"] for r in results]
|
||||
|
||||
return prev_txns
|
||||
return defer.succeed([])
|
||||
|
||||
def delivered_txn(self, transaction_id, destination, code, response_dict):
|
||||
"""Persists the response for an outgoing transaction.
|
||||
|
@ -198,52 +144,7 @@ class TransactionStore(SQLBaseStore):
|
|||
code (int)
|
||||
response_json (str)
|
||||
"""
|
||||
|
||||
txn_row = self.inflight_transactions.get(
|
||||
destination, {}
|
||||
).pop(transaction_id, None)
|
||||
|
||||
self.last_transaction[destination] = transaction_id
|
||||
|
||||
if txn_row:
|
||||
d = self.new_delivered_transactions.setdefault(destination, {})
|
||||
d[transaction_id] = txn_row._replace(
|
||||
response_code=code,
|
||||
response_json=None, # For now, don't persist response
|
||||
)
|
||||
else:
|
||||
d = self.update_delivered_transactions.setdefault(destination, {})
|
||||
# For now, don't persist response
|
||||
d[transaction_id] = _UpdateTransactionRow(code, None)
|
||||
|
||||
def get_transactions_after(self, transaction_id, destination):
|
||||
"""Get all transactions after a given local transaction_id.
|
||||
|
||||
Args:
|
||||
transaction_id (str)
|
||||
destination (str)
|
||||
|
||||
Returns:
|
||||
list: A list of dicts
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_transactions_after",
|
||||
self._get_transactions_after, transaction_id, destination
|
||||
)
|
||||
|
||||
def _get_transactions_after(self, txn, transaction_id, destination):
|
||||
query = (
|
||||
"SELECT * FROM sent_transactions"
|
||||
" WHERE destination = ? AND id >"
|
||||
" ("
|
||||
" SELECT id FROM sent_transactions"
|
||||
" WHERE transaction_id = ? AND destination = ?"
|
||||
" )"
|
||||
)
|
||||
|
||||
txn.execute(query, (destination, transaction_id, destination))
|
||||
|
||||
return self.cursor_to_dict(txn)
|
||||
pass
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def get_destination_retry_timings(self, destination):
|
||||
|
@ -339,58 +240,11 @@ class TransactionStore(SQLBaseStore):
|
|||
txn.execute(query, (self._clock.time_msec(),))
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _persist_in_mem_txns(self):
|
||||
try:
|
||||
inflight = self.inflight_transactions
|
||||
new_delivered = self.new_delivered_transactions
|
||||
update_delivered = self.update_delivered_transactions
|
||||
|
||||
self.inflight_transactions = {}
|
||||
self.new_delivered_transactions = {}
|
||||
self.update_delivered_transactions = {}
|
||||
|
||||
full_rows = [
|
||||
row._asdict()
|
||||
for txn_map in itertools.chain(inflight.values(), new_delivered.values())
|
||||
for row in txn_map.values()
|
||||
]
|
||||
|
||||
def f(txn):
|
||||
if full_rows:
|
||||
self._simple_insert_many_txn(
|
||||
txn=txn,
|
||||
table="sent_transactions",
|
||||
values=full_rows
|
||||
)
|
||||
|
||||
for dest, txn_map in update_delivered.items():
|
||||
for txn_id, update_row in txn_map.items():
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table="sent_transactions",
|
||||
keyvalues={
|
||||
"transaction_id": txn_id,
|
||||
"destination": dest,
|
||||
},
|
||||
updatevalues={
|
||||
"response_code": update_row.response_code,
|
||||
"response_json": None, # For now, don't persist response
|
||||
}
|
||||
)
|
||||
|
||||
if full_rows or update_delivered:
|
||||
yield self.runInteraction("_persist_in_mem_txns", f)
|
||||
except:
|
||||
logger.exception("Failed to persist transactions!")
|
||||
|
||||
def _cleanup_transactions(self):
|
||||
now = self._clock.time_msec()
|
||||
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||
six_hours_ago = now - 6 * 60 * 60 * 1000
|
||||
|
||||
def _cleanup_transactions_txn(txn):
|
||||
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
|
||||
|
||||
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
||||
return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
|
||||
|
|
|
@ -34,7 +34,7 @@ class Clock(object):
|
|||
"""A small utility that obtains current time-of-day so that time may be
|
||||
mocked during unit-tests.
|
||||
|
||||
TODO(paul): Also move the sleep() functionallity into it
|
||||
TODO(paul): Also move the sleep() functionality into it
|
||||
"""
|
||||
|
||||
def time(self):
|
||||
|
@ -46,6 +46,14 @@ class Clock(object):
|
|||
return int(self.time() * 1000)
|
||||
|
||||
def looping_call(self, f, msec):
|
||||
"""Call a function repeatedly.
|
||||
|
||||
Waits `msec` initially before calling `f` for the first time.
|
||||
|
||||
Args:
|
||||
f(function): The function to call repeatedly.
|
||||
msec(float): How long to wait between calls in milliseconds.
|
||||
"""
|
||||
l = task.LoopingCall(f)
|
||||
l.start(msec / 1000.0, now=False)
|
||||
return l
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
|
||||
from twisted.internet import defer
|
||||
from mock import Mock, call
|
||||
from tests import unittest
|
||||
from tests.utils import MockClock
|
||||
|
||||
|
||||
class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
self.cache = HttpTransactionCache(self.clock)
|
||||
|
||||
self.mock_http_response = (200, "GOOD JOB!")
|
||||
self.mock_key = "foo"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_executes_given_function(self):
|
||||
cb = Mock(
|
||||
return_value=defer.succeed(self.mock_http_response)
|
||||
)
|
||||
res = yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "some_arg", keyword="arg"
|
||||
)
|
||||
cb.assert_called_once_with("some_arg", keyword="arg")
|
||||
self.assertEqual(res, self.mock_http_response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_deduplicates_based_on_key(self):
|
||||
cb = Mock(
|
||||
return_value=defer.succeed(self.mock_http_response)
|
||||
)
|
||||
for i in range(3): # invoke multiple times
|
||||
res = yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
|
||||
)
|
||||
self.assertEqual(res, self.mock_http_response)
|
||||
# expect only a single call to do the work
|
||||
cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cleans_up(self):
|
||||
cb = Mock(
|
||||
return_value=defer.succeed(self.mock_http_response)
|
||||
)
|
||||
yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "an arg"
|
||||
)
|
||||
# should NOT have cleaned up yet
|
||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
|
||||
|
||||
yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "an arg"
|
||||
)
|
||||
# still using cache
|
||||
cb.assert_called_once_with("an arg")
|
||||
|
||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
|
||||
|
||||
yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "an arg"
|
||||
)
|
||||
# no longer using cache
|
||||
self.assertEqual(cb.call_count, 2)
|
||||
self.assertEqual(
|
||||
cb.call_args_list,
|
||||
[call("an arg",), call("an arg",)]
|
||||
)
|
Loading…
Reference in New Issue