Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fed_reader
This commit is contained in:
commit
5aa024e501
64
CHANGES.rst
64
CHANGES.rst
|
@ -1,3 +1,67 @@
|
||||||
|
Changes in synapse v0.17.0-rc1 (2016-07-28)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
This release changes the LDAP configuration format in a backwards incompatible
|
||||||
|
way, see PR #843 for details.
|
||||||
|
|
||||||
|
This release contains significant security bug fixes regarding authenticating
|
||||||
|
events received over federation. Please upgrade.
|
||||||
|
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Add purge_media_cache admin API (PR #902)
|
||||||
|
* Add deactivate account admin API (PR #903)
|
||||||
|
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
|
||||||
|
* Add an admin option to shared secret registration (breaks backwards compat)
|
||||||
|
(PR #909)
|
||||||
|
* Add purge local room history API (PR #911, #923, #924)
|
||||||
|
* Add requestToken endpoints (PR #915)
|
||||||
|
* Add an /account/deactivate endpoint (PR #921)
|
||||||
|
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
|
||||||
|
* Add device_id support to /login (PR #929)
|
||||||
|
* Add device_id support to /v2/register flow. (PR #937, #942)
|
||||||
|
* Add GET /devices endpoint (PR #939, #944)
|
||||||
|
* Add GET /device/{deviceId} (PR #943)
|
||||||
|
* Add update and delete APIs for devices (PR #949)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
|
||||||
|
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
|
||||||
|
* Remove the legacy v0 content upload API. (PR #888)
|
||||||
|
* Use similar naming we use in email notifs for push (PR #894)
|
||||||
|
* Optionally include password hash in createUser endpoint (PR #905 by
|
||||||
|
KentShikama)
|
||||||
|
* Use a query that postgresql optimises better for get_events_around (PR #906)
|
||||||
|
* Fall back to 'username' if 'user' is not given for appservice registration.
|
||||||
|
(PR #927 by Half-Shot)
|
||||||
|
* Add metrics for psutil derived memory usage (PR #936)
|
||||||
|
* Record device_id in client_ips (PR #938)
|
||||||
|
* Send the correct host header when fetching keys (PR #941)
|
||||||
|
* Log the hostname the reCAPTCHA was completed on (PR #946)
|
||||||
|
* Make the device id on e2e key upload optional (PR #956)
|
||||||
|
* Add r0.2.0 to the "supported versions" list (PR #960)
|
||||||
|
* Don't include name of room for invites in push (PR #961)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix substitution failure in mail template (PR #887)
|
||||||
|
* Put most recent 20 messages in email notif (PR #892)
|
||||||
|
* Ensure that the guest user is in the database when upgrading accounts
|
||||||
|
(PR #914)
|
||||||
|
* Fix various edge cases in auth handling (PR #919)
|
||||||
|
* Fix 500 ISE when sending alias event without a state_key (PR #925)
|
||||||
|
* Fix bug where we stored rejections in the state_group, persist all
|
||||||
|
rejections (PR #948)
|
||||||
|
* Fix lack of check of if the user is banned when handling 3pid invites
|
||||||
|
(PR #952)
|
||||||
|
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.16.1-r1 (2016-07-08)
|
Changes in synapse v0.16.1-r1 (2016-07-08)
|
||||||
==========================================
|
==========================================
|
||||||
|
|
||||||
|
|
|
@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
|
||||||
IDs:
|
IDs:
|
||||||
|
|
||||||
1) Use the machine's own hostname as available on public DNS in the form of
|
1) Use the machine's own hostname as available on public DNS in the form of
|
||||||
its A or AAAA records. This is easier to set up initially, perhaps for
|
its A records. This is easier to set up initially, perhaps for
|
||||||
testing, but lacks the flexibility of SRV.
|
testing, but lacks the flexibility of SRV.
|
||||||
|
|
||||||
2) Set up a SRV record for your domain name. This requires you create a SRV
|
2) Set up a SRV record for your domain name. This requires you create a SRV
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
Admin APIs
|
||||||
|
==========
|
||||||
|
|
||||||
|
This directory includes documentation for the various synapse specific admin
|
||||||
|
APIs available.
|
||||||
|
|
||||||
|
Only users that are server admins can use these APIs. A user can be marked as a
|
||||||
|
server admin by updating the database directly, e.g.:
|
||||||
|
|
||||||
|
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
|
||||||
|
|
||||||
|
Restarting may be required for the changes to register.
|
|
@ -0,0 +1,15 @@
|
||||||
|
Purge History API
|
||||||
|
=================
|
||||||
|
|
||||||
|
The purge history API allows server admins to purge historic events from their
|
||||||
|
database, reclaiming disk space.
|
||||||
|
|
||||||
|
Depending on the amount of history being purged a call to the API may take
|
||||||
|
several minutes or longer. During this period users will not be able to
|
||||||
|
paginate further back in the room from the point being purged from.
|
||||||
|
|
||||||
|
The API is simply:
|
||||||
|
|
||||||
|
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
||||||
|
|
||||||
|
including an ``access_token`` of a server admin.
|
|
@ -0,0 +1,19 @@
|
||||||
|
Purge Remote Media API
|
||||||
|
======================
|
||||||
|
|
||||||
|
The purge remote media API allows server admins to purge old cached remote
|
||||||
|
media.
|
||||||
|
|
||||||
|
The API is::
|
||||||
|
|
||||||
|
POST /_matrix/client/r0/admin/purge_media_cache
|
||||||
|
|
||||||
|
{
|
||||||
|
"before_ts": <unix_timestamp_in_ms>
|
||||||
|
}
|
||||||
|
|
||||||
|
Which will remove all cached media that was last accessed before
|
||||||
|
``<unix_timestamp_in_ms>``.
|
||||||
|
|
||||||
|
If the user re-requests purged remote media, synapse will re-request the media
|
||||||
|
from the originating server.
|
|
@ -16,7 +16,5 @@ ignore =
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 90
|
max-line-length = 90
|
||||||
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||||
|
ignore = W503
|
||||||
[pep8]
|
|
||||||
max-line-length = 90
|
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.16.1-r1"
|
__version__ = "0.17.0-rc1"
|
||||||
|
|
|
@ -13,22 +13,22 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pymacaroons
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
|
||||||
from synapse.types import Requester, UserID, get_domain_from_id
|
|
||||||
from synapse.util.logutils import log_function
|
|
||||||
from synapse.util.logcontext import preserve_context_over_fn
|
|
||||||
from synapse.util.metrics import Measure
|
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
import logging
|
import synapse.types
|
||||||
import pymacaroons
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
|
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||||
|
from synapse.types import UserID, get_domain_from_id
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -376,6 +376,10 @@ class Auth(object):
|
||||||
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
||||||
if not self._verify_third_party_invite(event, auth_events):
|
if not self._verify_third_party_invite(event, auth_events):
|
||||||
raise AuthError(403, "You are not invited to this room.")
|
raise AuthError(403, "You are not invited to this room.")
|
||||||
|
if target_banned:
|
||||||
|
raise AuthError(
|
||||||
|
403, "%s is banned from the room" % (target_user_id,)
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if Membership.JOIN != membership:
|
if Membership.JOIN != membership:
|
||||||
|
@ -566,8 +570,7 @@ class Auth(object):
|
||||||
Args:
|
Args:
|
||||||
request - An HTTP request with an access_token query parameter.
|
request - An HTTP request with an access_token query parameter.
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to a namedtuple including "user" (UserID)
|
defer.Deferred: resolves to a ``synapse.types.Requester`` object
|
||||||
"access_token_id" (int), "is_guest" (bool)
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
|
@ -576,9 +579,7 @@ class Auth(object):
|
||||||
user_id = yield self._get_appservice_user_id(request.args)
|
user_id = yield self._get_appservice_user_id(request.args)
|
||||||
if user_id:
|
if user_id:
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
defer.returnValue(
|
defer.returnValue(synapse.types.create_requester(user_id))
|
||||||
Requester(UserID.from_string(user_id), "", False)
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token = request.args["access_token"][0]
|
access_token = request.args["access_token"][0]
|
||||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||||
|
@ -612,7 +613,8 @@ class Auth(object):
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
request.authenticated_entity = user.to_string()
|
||||||
|
|
||||||
defer.returnValue(Requester(user, token_id, is_guest))
|
defer.returnValue(synapse.types.create_requester(
|
||||||
|
user, token_id, is_guest, device_id))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||||
|
|
|
@ -16,13 +16,11 @@
|
||||||
import sys
|
import sys
|
||||||
sys.dont_write_bytecode = True
|
sys.dont_write_bytecode = True
|
||||||
|
|
||||||
from synapse.python_dependencies import (
|
from synapse import python_dependencies # noqa: E402
|
||||||
check_requirements, MissingRequirementError
|
|
||||||
) # NOQA
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
check_requirements()
|
python_dependencies.check_requirements()
|
||||||
except MissingRequirementError as e:
|
except python_dependencies.MissingRequirementError as e:
|
||||||
message = "\n".join([
|
message = "\n".join([
|
||||||
"Missing Requirement: %s" % (e.message,),
|
"Missing Requirement: %s" % (e.message,),
|
||||||
"To install run:",
|
"To install run:",
|
||||||
|
|
|
@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.remote_key = defer.Deferred()
|
self.remote_key = defer.Deferred()
|
||||||
self.host = None
|
self.host = None
|
||||||
|
self._peer = None
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
self.host = self.transport.getHost()
|
self._peer = self.transport.getPeer()
|
||||||
logger.debug("Connected to %s", self.host)
|
logger.debug("Connected to %s", self._peer)
|
||||||
|
|
||||||
self.sendCommand(b"GET", self.path)
|
self.sendCommand(b"GET", self.path)
|
||||||
if self.host:
|
if self.host:
|
||||||
self.sendHeader(b"Host", self.host)
|
self.sendHeader(b"Host", self.host)
|
||||||
|
@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||||
self.timer.cancel()
|
self.timer.cancel()
|
||||||
|
|
||||||
def on_timeout(self):
|
def on_timeout(self):
|
||||||
logger.debug("Timeout waiting for response from %s", self.host)
|
logger.debug(
|
||||||
|
"Timeout waiting for response from %s: %s",
|
||||||
|
self.host, self._peer,
|
||||||
|
)
|
||||||
self.errback(IOError("Timeout waiting for response"))
|
self.errback(IOError("Timeout waiting for response"))
|
||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
|
|
||||||
|
@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
|
||||||
def protocol(self):
|
def protocol(self):
|
||||||
protocol = SynapseKeyClientProtocol()
|
protocol = SynapseKeyClientProtocol()
|
||||||
protocol.path = self.path
|
protocol.path = self.path
|
||||||
|
protocol.host = self.host
|
||||||
return protocol
|
return protocol
|
||||||
|
|
|
@ -44,7 +44,21 @@ import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
VerifyKeyRequest = namedtuple("VerifyRequest", (
|
||||||
|
"server_name", "key_ids", "json_object", "deferred"
|
||||||
|
))
|
||||||
|
"""
|
||||||
|
A request for a verify key to verify a JSON object.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
server_name(str): The name of the server to verify against.
|
||||||
|
key_ids(set(str)): The set of key_ids to that could be used to verify the
|
||||||
|
JSON object
|
||||||
|
json_object(dict): The JSON object to verify.
|
||||||
|
deferred(twisted.internet.defer.Deferred):
|
||||||
|
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||||
|
a verify key has been fetched
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
|
@ -74,39 +88,32 @@ class Keyring(object):
|
||||||
list of deferreds indicating success or failure to verify each
|
list of deferreds indicating success or failure to verify each
|
||||||
json object's signature for the given server_name.
|
json object's signature for the given server_name.
|
||||||
"""
|
"""
|
||||||
group_id_to_json = {}
|
verify_requests = []
|
||||||
group_id_to_group = {}
|
|
||||||
group_ids = []
|
|
||||||
|
|
||||||
next_group_id = 0
|
|
||||||
deferreds = {}
|
|
||||||
|
|
||||||
for server_name, json_object in server_and_json:
|
for server_name, json_object in server_and_json:
|
||||||
logger.debug("Verifying for %s", server_name)
|
logger.debug("Verifying for %s", server_name)
|
||||||
group_id = next_group_id
|
|
||||||
next_group_id += 1
|
|
||||||
group_ids.append(group_id)
|
|
||||||
|
|
||||||
key_ids = signature_ids(json_object, server_name)
|
key_ids = signature_ids(json_object, server_name)
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
deferreds[group_id] = defer.fail(SynapseError(
|
deferred = defer.fail(SynapseError(
|
||||||
400,
|
400,
|
||||||
"Not signed with a supported algorithm",
|
"Not signed with a supported algorithm",
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
deferreds[group_id] = defer.Deferred()
|
deferred = defer.Deferred()
|
||||||
|
|
||||||
group = KeyGroup(server_name, group_id, key_ids)
|
verify_request = VerifyKeyRequest(
|
||||||
|
server_name, key_ids, json_object, deferred
|
||||||
|
)
|
||||||
|
|
||||||
group_id_to_group[group_id] = group
|
verify_requests.append(verify_request)
|
||||||
group_id_to_json[group_id] = json_object
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_key_deferred(group, deferred):
|
def handle_key_deferred(verify_request):
|
||||||
server_name = group.server_name
|
server_name = verify_request.server_name
|
||||||
try:
|
try:
|
||||||
_, _, key_id, verify_key = yield deferred
|
_, key_id, verify_key = yield verify_request.deferred
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Got IOError when downloading keys for %s: %s %s",
|
"Got IOError when downloading keys for %s: %s %s",
|
||||||
|
@ -128,7 +135,7 @@ class Keyring(object):
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
json_object = group_id_to_json[group.group_id]
|
json_object = verify_request.json_object
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
|
@ -157,36 +164,34 @@ class Keyring(object):
|
||||||
|
|
||||||
# Actually start fetching keys.
|
# Actually start fetching keys.
|
||||||
wait_on_deferred.addBoth(
|
wait_on_deferred.addBoth(
|
||||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
lambda _: self.get_server_verify_keys(verify_requests)
|
||||||
)
|
)
|
||||||
|
|
||||||
# When we've finished fetching all the keys for a given server_name,
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||||
# any lookups waiting will proceed.
|
# any lookups waiting will proceed.
|
||||||
server_to_gids = {}
|
server_to_request_ids = {}
|
||||||
|
|
||||||
def remove_deferreds(res, server_name, group_id):
|
def remove_deferreds(res, server_name, verify_request):
|
||||||
server_to_gids[server_name].discard(group_id)
|
request_id = id(verify_request)
|
||||||
if not server_to_gids[server_name]:
|
server_to_request_ids[server_name].discard(request_id)
|
||||||
|
if not server_to_request_ids[server_name]:
|
||||||
d = server_to_deferred.pop(server_name, None)
|
d = server_to_deferred.pop(server_name, None)
|
||||||
if d:
|
if d:
|
||||||
d.callback(None)
|
d.callback(None)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
for g_id, deferred in deferreds.items():
|
for verify_request in verify_requests:
|
||||||
server_name = group_id_to_group[g_id].server_name
|
server_name = verify_request.server_name
|
||||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
request_id = id(verify_request)
|
||||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||||
|
deferred.addBoth(remove_deferreds, server_name, verify_request)
|
||||||
|
|
||||||
# Pass those keys to handle_key_deferred so that the json object
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
# signatures can be verified
|
# signatures can be verified
|
||||||
return [
|
return [
|
||||||
preserve_context_over_fn(
|
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||||
handle_key_deferred,
|
for verify_request in verify_requests
|
||||||
group_id_to_group[g_id],
|
|
||||||
deferreds[g_id],
|
|
||||||
)
|
|
||||||
for g_id in group_ids
|
|
||||||
]
|
]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -220,7 +225,7 @@ class Keyring(object):
|
||||||
|
|
||||||
d.addBoth(rm, server_name)
|
d.addBoth(rm, server_name)
|
||||||
|
|
||||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
def get_server_verify_keys(self, verify_requests):
|
||||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||||
each group.
|
each group.
|
||||||
"""
|
"""
|
||||||
|
@ -237,62 +242,64 @@ class Keyring(object):
|
||||||
merged_results = {}
|
merged_results = {}
|
||||||
|
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
for group in group_id_to_group.values():
|
for verify_request in verify_requests:
|
||||||
missing_keys.setdefault(group.server_name, set()).update(
|
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||||
group.key_ids
|
verify_request.key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for fn in key_fetch_fns:
|
for fn in key_fetch_fns:
|
||||||
results = yield fn(missing_keys.items())
|
results = yield fn(missing_keys.items())
|
||||||
merged_results.update(results)
|
merged_results.update(results)
|
||||||
|
|
||||||
# We now need to figure out which groups we have keys for
|
# We now need to figure out which verify requests we have keys
|
||||||
# and which we don't
|
# for and which we don't
|
||||||
missing_groups = {}
|
missing_keys = {}
|
||||||
for group in group_id_to_group.values():
|
requests_missing_keys = []
|
||||||
for key_id in group.key_ids:
|
for verify_request in verify_requests:
|
||||||
if key_id in merged_results[group.server_name]:
|
server_name = verify_request.server_name
|
||||||
|
result_keys = merged_results[server_name]
|
||||||
|
|
||||||
|
if verify_request.deferred.called:
|
||||||
|
# We've already called this deferred, which probably
|
||||||
|
# means that we've already found a key for it.
|
||||||
|
continue
|
||||||
|
|
||||||
|
for key_id in verify_request.key_ids:
|
||||||
|
if key_id in result_keys:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
group_id_to_deferred[group.group_id].callback((
|
verify_request.deferred.callback((
|
||||||
group.group_id,
|
server_name,
|
||||||
group.server_name,
|
|
||||||
key_id,
|
key_id,
|
||||||
merged_results[group.server_name][key_id],
|
result_keys[key_id],
|
||||||
))
|
))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
missing_groups.setdefault(
|
# The else block is only reached if the loop above
|
||||||
group.server_name, []
|
# doesn't break.
|
||||||
).append(group)
|
missing_keys.setdefault(server_name, set()).update(
|
||||||
|
verify_request.key_ids
|
||||||
|
)
|
||||||
|
requests_missing_keys.append(verify_request)
|
||||||
|
|
||||||
if not missing_groups:
|
if not missing_keys:
|
||||||
break
|
break
|
||||||
|
|
||||||
missing_keys = {
|
for verify_request in requests_missing_keys.values():
|
||||||
server_name: set(
|
verify_request.deferred.errback(SynapseError(
|
||||||
key_id for group in groups for key_id in group.key_ids
|
|
||||||
)
|
|
||||||
for server_name, groups in missing_groups.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
for group in missing_groups.values():
|
|
||||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
|
||||||
401,
|
401,
|
||||||
"No key for %s with id %s" % (
|
"No key for %s with id %s" % (
|
||||||
group.server_name, group.key_ids,
|
verify_request.server_name, verify_request.key_ids,
|
||||||
),
|
),
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
|
|
||||||
def on_err(err):
|
def on_err(err):
|
||||||
for deferred in group_id_to_deferred.values():
|
for verify_request in verify_requests:
|
||||||
if not deferred.called:
|
if not verify_request.deferred.called:
|
||||||
deferred.errback(err)
|
verify_request.deferred.errback(err)
|
||||||
|
|
||||||
do_iterations().addErrback(on_err)
|
do_iterations().addErrback(on_err)
|
||||||
|
|
||||||
return group_id_to_deferred
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys_from_store(self, server_name_and_key_ids):
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
res = yield defer.gatherResults(
|
res = yield defer.gatherResults(
|
||||||
|
@ -447,7 +454,7 @@ class Keyring(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
processed_response = yield self.process_v2_response(
|
processed_response = yield self.process_v2_response(
|
||||||
perspective_name, response
|
perspective_name, response, only_from_server=False
|
||||||
)
|
)
|
||||||
|
|
||||||
for server_name, response_keys in processed_response.items():
|
for server_name, response_keys in processed_response.items():
|
||||||
|
@ -527,7 +534,7 @@ class Keyring(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def process_v2_response(self, from_server, response_json,
|
def process_v2_response(self, from_server, response_json,
|
||||||
requested_ids=[]):
|
requested_ids=[], only_from_server=True):
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
response_keys = {}
|
response_keys = {}
|
||||||
verify_keys = {}
|
verify_keys = {}
|
||||||
|
@ -551,6 +558,13 @@ class Keyring(object):
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
server_name = response_json["server_name"]
|
server_name = response_json["server_name"]
|
||||||
|
if only_from_server:
|
||||||
|
if server_name != from_server:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected a response for server %r not %r" % (
|
||||||
|
from_server, server_name
|
||||||
|
)
|
||||||
|
)
|
||||||
for key_id in response_json["signatures"].get(server_name, {}):
|
for key_id in response_json["signatures"].get(server_name, {}):
|
||||||
if key_id not in response_json["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import LimitExceededError
|
import synapse.types
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import UserID, Requester
|
from synapse.api.errors import LimitExceededError
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -124,7 +124,8 @@ class BaseHandler(object):
|
||||||
# and having homeservers have their own users leave keeps more
|
# and having homeservers have their own users leave keeps more
|
||||||
# of that decision-making and control local to the guest-having
|
# of that decision-making and control local to the guest-having
|
||||||
# homeserver.
|
# homeserver.
|
||||||
requester = Requester(target_user, "", True)
|
requester = synapse.types.create_requester(
|
||||||
|
target_user, is_guest=True)
|
||||||
handler = self.hs.get_handlers().room_member_handler
|
handler = self.hs.get_handlers().room_member_handler
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
|
|
|
@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
|
||||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -279,7 +280,16 @@ class AuthHandler(BaseHandler):
|
||||||
data = pde.response
|
data = pde.response
|
||||||
resp_body = simplejson.loads(data)
|
resp_body = simplejson.loads(data)
|
||||||
|
|
||||||
if 'success' in resp_body and resp_body['success']:
|
if 'success' in resp_body:
|
||||||
|
# Note that we do NOT check the hostname here: we explicitly
|
||||||
|
# intend the CAPTCHA to be presented by whatever client the
|
||||||
|
# user is using, we just care that they have completed a CAPTCHA.
|
||||||
|
logger.info(
|
||||||
|
"%s reCAPTCHA from hostname %s",
|
||||||
|
"Successful" if resp_body['success'] else "Failed",
|
||||||
|
resp_body.get('hostname')
|
||||||
|
)
|
||||||
|
if resp_body['success']:
|
||||||
defer.returnValue(True)
|
defer.returnValue(True)
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
@ -365,7 +375,8 @@ class AuthHandler(BaseHandler):
|
||||||
return self._check_password(user_id, password)
|
return self._check_password(user_id, password)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_login_tuple_for_user_id(self, user_id, device_id=None):
|
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||||
|
initial_display_name=None):
|
||||||
"""
|
"""
|
||||||
Gets login tuple for the user with the given user ID.
|
Gets login tuple for the user with the given user ID.
|
||||||
|
|
||||||
|
@ -374,9 +385,15 @@ class AuthHandler(BaseHandler):
|
||||||
The user is assumed to have been authenticated by some other
|
The user is assumed to have been authenticated by some other
|
||||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||||
|
|
||||||
|
The device will be recorded in the table if it is not there already.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): canonical User ID
|
user_id (str): canonical User ID
|
||||||
device_id (str): the device ID to associate with the access token
|
device_id (str|None): the device ID to associate with the tokens.
|
||||||
|
None to leave the tokens unassociated with a device (deprecated:
|
||||||
|
we should always have a device ID)
|
||||||
|
initial_display_name (str): display name to associate with the
|
||||||
|
device if it needs re-registering
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
A tuple of:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
|
@ -388,6 +405,16 @@ class AuthHandler(BaseHandler):
|
||||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||||
access_token = yield self.issue_access_token(user_id, device_id)
|
access_token = yield self.issue_access_token(user_id, device_id)
|
||||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||||
|
|
||||||
|
# the device *should* have been registered before we got here; however,
|
||||||
|
# it's possible we raced against a DELETE operation. The thing we
|
||||||
|
# really don't want is active access_tokens without a record of the
|
||||||
|
# device, so we double-check it here.
|
||||||
|
if device_id is not None:
|
||||||
|
yield self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((access_token, refresh_token))
|
defer.returnValue((access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -79,17 +79,17 @@ class DeviceHandler(BaseHandler):
|
||||||
Args:
|
Args:
|
||||||
user_id (str):
|
user_id (str):
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: dict[str, dict[str, X]]: map from device_id to
|
defer.Deferred: list[dict[str, X]]: info on each device
|
||||||
info on the device
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
devices = yield self.store.get_devices_by_user(user_id)
|
device_map = yield self.store.get_devices_by_user(user_id)
|
||||||
|
|
||||||
ips = yield self.store.get_last_client_ip_by_device(
|
ips = yield self.store.get_last_client_ip_by_device(
|
||||||
devices=((user_id, device_id) for device_id in devices.keys())
|
devices=((user_id, device_id) for device_id in device_map.keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
for device in devices.values():
|
devices = device_map.values()
|
||||||
|
for device in devices:
|
||||||
_update_device_from_client_ips(device, ips)
|
_update_device_from_client_ips(device, ips)
|
||||||
|
|
||||||
defer.returnValue(devices)
|
defer.returnValue(devices)
|
||||||
|
@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str):
|
user_id (str):
|
||||||
device_id (str)
|
device_id (str):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: dict[str, X]: info on the device
|
defer.Deferred: dict[str, X]: info on the device
|
||||||
|
@ -117,6 +117,61 @@ class DeviceHandler(BaseHandler):
|
||||||
_update_device_from_client_ips(device, ips)
|
_update_device_from_client_ips(device, ips)
|
||||||
defer.returnValue(device)
|
defer.returnValue(device)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_device(self, user_id, device_id):
|
||||||
|
""" Delete the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.delete_device(user_id, device_id)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
# no match
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
yield self.store.user_delete_access_tokens(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
delete_refresh_tokens=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.delete_e2e_keys_by_device(
|
||||||
|
user_id=user_id, device_id=device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_device(self, user_id, device_id, content):
|
||||||
|
""" Update the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
content (dict): body of update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.update_device(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
new_display_name=content.get("display_name")
|
||||||
|
)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
raise errors.NotFoundError()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _update_device_from_client_ips(device, client_ips):
|
def _update_device_from_client_ips(device, client_ips):
|
||||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||||
|
|
|
@ -13,15 +13,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||||
from synapse.types import UserID, Requester
|
from synapse.types import UserID
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# Assume the user isn't a guest because we don't let guests set
|
# Assume the user isn't a guest because we don't let guests set
|
||||||
# profile or avatar data.
|
# profile or avatar data.
|
||||||
requester = Requester(user, "", False)
|
# XXX why are we recreating `requester` here for each room?
|
||||||
|
# what was wrong with the `requester` we were passed?
|
||||||
|
requester = synapse.types.create_requester(user)
|
||||||
yield handler.update_membership(
|
yield handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
user,
|
user,
|
||||||
|
|
|
@ -14,18 +14,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Contains functions for registering clients."""
|
"""Contains functions for registering clients."""
|
||||||
|
import logging
|
||||||
|
import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import UserID, Requester
|
import synapse.types
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||||
)
|
)
|
||||||
from ._base import BaseHandler
|
|
||||||
from synapse.util.async import run_on_reactor
|
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
from synapse.types import UserID
|
||||||
import logging
|
from synapse.util.async import run_on_reactor
|
||||||
import urllib
|
from ._base import BaseHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
|
requester = synapse.types.create_requester(user)
|
||||||
yield profile_handler.set_displayname(
|
yield profile_handler.set_displayname(
|
||||||
user, Requester(user, token, False), displayname
|
user, requester, displayname
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
|
@ -14,24 +14,22 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
from signedjson.sign import verify_signed_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
from ._base import BaseHandler
|
import synapse.types
|
||||||
|
|
||||||
from synapse.types import UserID, RoomID, Requester
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes, Membership,
|
EventTypes, Membership,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||||
|
from synapse.types import UserID, RoomID
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.util.distributor import user_left_room, user_joined_room
|
from synapse.util.distributor import user_left_room, user_joined_room
|
||||||
|
from ._base import BaseHandler
|
||||||
from signedjson.sign import verify_signed_json
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
|
||||||
|
|
||||||
from unpaddedbase64 import decode_base64
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||||
else:
|
else:
|
||||||
requester = Requester(target_user, None, False)
|
requester = synapse.types.create_requester(target_user)
|
||||||
|
|
||||||
message_handler = self.hs.get_handlers().message_handler
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||||
|
|
|
@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
|
|
||||||
def register_paths(self, method, path_patterns, callback):
|
def register_paths(self, method, path_patterns, callback):
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
self._PathEntry(path_pattern, callback)
|
self._PathEntry(path_pattern, callback)
|
||||||
)
|
)
|
||||||
|
|
|
@ -140,9 +140,8 @@ class EmailPusher(object):
|
||||||
being run.
|
being run.
|
||||||
"""
|
"""
|
||||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
|
||||||
self.user_id, start, self.max_stream_ordering
|
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
|
||||||
)
|
|
||||||
|
|
||||||
soonest_due_at = None
|
soonest_due_at = None
|
||||||
|
|
||||||
|
|
|
@ -141,7 +141,8 @@ class HttpPusher(object):
|
||||||
run once per pusher.
|
run once per pusher.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||||
|
unprocessed = yield fn(
|
||||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ def get_context_for_event(state_handler, ev, user_id):
|
||||||
room_state = yield state_handler.get_current_state(ev.room_id)
|
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||||
|
|
||||||
# we no longer bother setting room_alias, and make room_name the
|
# we no longer bother setting room_alias, and make room_name the
|
||||||
# human-readable name instead, be that m.room.namer, an alias or
|
# human-readable name instead, be that m.room.name, an alias or
|
||||||
# a list of people in the room
|
# a list of people in the room
|
||||||
name = calculate_room_name(
|
name = calculate_room_name(
|
||||||
room_state, user_id, fallback_to_single_member=False
|
room_state, user_id, fallback_to_single_member=False
|
||||||
|
|
|
@ -93,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||||
)
|
)
|
||||||
|
|
||||||
get_unread_push_actions_for_user_in_range = (
|
get_unread_push_actions_for_user_in_range_for_http = (
|
||||||
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||||
|
)
|
||||||
|
get_unread_push_actions_for_user_in_range_for_email = (
|
||||||
|
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
|
||||||
)
|
)
|
||||||
get_push_action_users_in_range = (
|
get_push_action_users_in_range = (
|
||||||
DataStore.get_push_action_users_in_range.__func__
|
DataStore.get_push_action_users_in_range.__func__
|
||||||
|
|
|
@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
|
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
device_id = yield self._register_device(user_id, login_submission)
|
||||||
access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
|
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
access_token, refresh_token = (
|
access_token, refresh_token = (
|
||||||
yield auth_handler.get_login_tuple_for_user_id(
|
yield auth_handler.get_login_tuple_for_user_id(
|
||||||
registered_user_id, device_id
|
registered_user_id, device_id,
|
||||||
|
login_submission.get("initial_device_display_name")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
|
|
|
@ -13,19 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet
|
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http import servlet
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DevicesRestServlet(RestServlet):
|
class DevicesRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
|
||||||
defer.returnValue((200, {"devices": devices}))
|
defer.returnValue((200, {"devices": devices}))
|
||||||
|
|
||||||
|
|
||||||
class DeviceRestServlet(RestServlet):
|
class DeviceRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||||
releases=[], v2_alpha=False)
|
releases=[], v2_alpha=False)
|
||||||
|
|
||||||
|
@ -70,6 +68,32 @@ class DeviceRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
defer.returnValue((200, device))
|
defer.returnValue((200, device))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_DELETE(self, request, device_id):
|
||||||
|
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||||
|
# It allows the client to delete access tokens, which feels like a
|
||||||
|
# thing which merits extra auth. But if we want to do the interactive-
|
||||||
|
# auth dance, we should really make it possible to delete more than one
|
||||||
|
# device at a time.
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
yield self.device_handler.delete_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, device_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
body = servlet.parse_json_object_from_request(request)
|
||||||
|
yield self.device_handler.update_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
body
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
DevicesRestServlet(hs).register(http_server)
|
DevicesRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -13,24 +13,25 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
from canonicaljson import encode_canonical_json
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
|
import synapse.server
|
||||||
|
import synapse.types
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
import logging
|
|
||||||
import simplejson as json
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class KeyUploadServlet(RestServlet):
|
class KeyUploadServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
POST /keys/upload/<device_id> HTTP/1.1
|
POST /keys/upload HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
|
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||||
|
releases=())
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
super(KeyUploadServlet, self).__init__()
|
super(KeyUploadServlet, self).__init__()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, device_id):
|
def on_POST(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
# TODO: Check that the device_id matches that in the authentication
|
|
||||||
# or derive the device_id from the authentication instead.
|
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
if device_id is not None:
|
||||||
|
# passing the device_id here is deprecated; however, we allow it
|
||||||
|
# for now for compatibility with older clients.
|
||||||
|
if (requester.device_id is not None and
|
||||||
|
device_id != requester.device_id):
|
||||||
|
logger.warning("Client uploading keys for a different device "
|
||||||
|
"(logged in as %s, uploading for %s)",
|
||||||
|
requester.device_id, device_id)
|
||||||
|
else:
|
||||||
|
device_id = requester.device_id
|
||||||
|
|
||||||
|
if device_id is None:
|
||||||
|
raise synapse.api.errors.SynapseError(
|
||||||
|
400,
|
||||||
|
"To upload keys, you must pass device_id when authenticating"
|
||||||
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
# TODO: Validate the JSON to make sure it has the right keys.
|
# TODO: Validate the JSON to make sure it has the right keys.
|
||||||
|
@ -102,13 +125,14 @@ class KeyUploadServlet(RestServlet):
|
||||||
user_id, device_id, time_now, key_list
|
user_id, device_id, time_now, key_list
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
# the device should have been registered already, but it may have been
|
||||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
# deleted due to a race with a DELETE request. Or we may be using an
|
||||||
|
# old access_token without an associated device_id. Either way, we
|
||||||
@defer.inlineCallbacks
|
# need to double-check the device is registered to avoid ending up with
|
||||||
def on_GET(self, request, device_id):
|
# keys without a corresponding device.
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
self.device_handler.check_device_registered(
|
||||||
user_id = requester.user.to_string()
|
user_id, device_id, "unknown device"
|
||||||
|
)
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
|
|
@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
|
||||||
"""
|
"""
|
||||||
device_id = yield self._register_device(user_id, params)
|
device_id = yield self._register_device(user_id, params)
|
||||||
|
|
||||||
access_token = yield self.auth_handler.issue_access_token(
|
access_token, refresh_token = (
|
||||||
user_id, device_id=device_id
|
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
initial_display_name=params.get("initial_device_display_name")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
|
||||||
user_id, device_id=device_id
|
|
||||||
)
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
|
|
@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
return (200, {
|
return (200, {
|
||||||
"versions": ["r0.0.1"]
|
"versions": [
|
||||||
|
"r0.0.1",
|
||||||
|
"r0.1.0",
|
||||||
|
"r0.2.0",
|
||||||
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from . import engines
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def start_doing_background_updates(self):
|
def start_doing_background_updates(self):
|
||||||
while True:
|
assert self._background_update_timer is None, \
|
||||||
if self._background_update_timer is not None:
|
"background updates already running"
|
||||||
return
|
|
||||||
|
|
||||||
|
logger.info("Starting background schema updates")
|
||||||
|
|
||||||
|
while True:
|
||||||
sleep = defer.Deferred()
|
sleep = defer.Deferred()
|
||||||
self._background_update_timer = self._clock.call_later(
|
self._background_update_timer = self._clock.call_later(
|
||||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
||||||
|
@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_update_timer = None
|
self._background_update_timer = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = yield self.do_background_update(
|
result = yield self.do_next_background_update(
|
||||||
self.BACKGROUND_UPDATE_DURATION_MS
|
self.BACKGROUND_UPDATE_DURATION_MS
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception("Error doing update")
|
logger.exception("Error doing update")
|
||||||
|
else:
|
||||||
if result is None:
|
if result is None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"No more background updates to do."
|
"No more background updates to do."
|
||||||
" Unscheduling background update task."
|
" Unscheduling background update task."
|
||||||
)
|
)
|
||||||
return
|
defer.returnValue(None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_background_update(self, desired_duration_ms):
|
def do_next_background_update(self, desired_duration_ms):
|
||||||
"""Does some amount of work on a background update
|
"""Does some amount of work on the next queued background update
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
desired_duration_ms(float): How long we want to spend
|
desired_duration_ms(float): How long we want to spend
|
||||||
updating.
|
updating.
|
||||||
|
@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_update_queue.append(update['update_name'])
|
self._background_update_queue.append(update['update_name'])
|
||||||
|
|
||||||
if not self._background_update_queue:
|
if not self._background_update_queue:
|
||||||
|
# no work left to do
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
# pop from the front, and add back to the back
|
||||||
update_name = self._background_update_queue.pop(0)
|
update_name = self._background_update_queue.pop(0)
|
||||||
self._background_update_queue.append(update_name)
|
self._background_update_queue.append(update_name)
|
||||||
|
|
||||||
|
res = yield self._do_background_update(update_name, desired_duration_ms)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_background_update(self, update_name, desired_duration_ms):
|
||||||
|
logger.info("Starting update batch on background update '%s'",
|
||||||
|
update_name)
|
||||||
|
|
||||||
update_handler = self._background_update_handlers[update_name]
|
update_handler = self._background_update_handlers[update_name]
|
||||||
|
|
||||||
performance = self._background_update_performance.get(update_name)
|
performance = self._background_update_performance.get(update_name)
|
||||||
|
@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
self._background_update_handlers[update_name] = update_handler
|
self._background_update_handlers[update_name] = update_handler
|
||||||
|
|
||||||
|
def register_background_index_update(self, update_name, index_name,
|
||||||
|
table, columns):
|
||||||
|
"""Helper for store classes to do a background index addition
|
||||||
|
|
||||||
|
To use:
|
||||||
|
|
||||||
|
1. use a schema delta file to add a background update. Example:
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('my_new_index', '{}');
|
||||||
|
|
||||||
|
2. In the Store constructor, call this method
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_name (str): update_name to register for
|
||||||
|
index_name (str): name of index to add
|
||||||
|
table (str): table to add index to
|
||||||
|
columns (list[str]): columns/expressions to include in index
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if this is postgres, we add the indexes concurrently. Otherwise
|
||||||
|
# we fall back to doing it inline
|
||||||
|
if isinstance(self.database_engine, engines.PostgresEngine):
|
||||||
|
conc = True
|
||||||
|
else:
|
||||||
|
conc = False
|
||||||
|
|
||||||
|
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
|
||||||
|
% {
|
||||||
|
"conc": "CONCURRENTLY" if conc else "",
|
||||||
|
"name": index_name,
|
||||||
|
"table": table,
|
||||||
|
"columns": ", ".join(columns),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_index_concurrently(conn):
|
||||||
|
conn.rollback()
|
||||||
|
# postgres insists on autocommit for the index
|
||||||
|
conn.set_session(autocommit=True)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(sql)
|
||||||
|
conn.set_session(autocommit=False)
|
||||||
|
|
||||||
|
def create_index(conn):
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute(sql)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def updater(progress, batch_size):
|
||||||
|
logger.info("Adding index %s to %s", index_name, table)
|
||||||
|
if conc:
|
||||||
|
yield self.runWithConnection(create_index_concurrently)
|
||||||
|
else:
|
||||||
|
yield self.runWithConnection(create_index)
|
||||||
|
yield self._end_background_update(update_name)
|
||||||
|
defer.returnValue(1)
|
||||||
|
|
||||||
|
self.register_background_update_handler(update_name, updater)
|
||||||
|
|
||||||
def start_background_update(self, update_name, progress):
|
def start_background_update(self, update_name, progress):
|
||||||
"""Starts a background update running.
|
"""Starts a background update running.
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,11 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from ._base import SQLBaseStore, Cache
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from ._base import Cache
|
||||||
|
from . import background_updates
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||||
|
@ -27,8 +28,7 @@ logger = logging.getLogger(__name__)
|
||||||
LAST_SEEN_GRANULARITY = 120 * 1000
|
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||||
|
|
||||||
|
|
||||||
class ClientIpStore(SQLBaseStore):
|
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.client_ip_last_seen = Cache(
|
self.client_ip_last_seen = Cache(
|
||||||
name="client_ip_last_seen",
|
name="client_ip_last_seen",
|
||||||
|
@ -37,6 +37,13 @@ class ClientIpStore(SQLBaseStore):
|
||||||
|
|
||||||
super(ClientIpStore, self).__init__(hs)
|
super(ClientIpStore, self).__init__(hs)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"user_ips_device_index",
|
||||||
|
index_name="user_ips_device_id",
|
||||||
|
table="user_ips",
|
||||||
|
columns=["user_id", "device_id", "last_seen"],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
|
|
|
@ -76,6 +76,46 @@ class DeviceStore(SQLBaseStore):
|
||||||
desc="get_device",
|
desc="get_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_device(self, user_id, device_id):
|
||||||
|
"""Delete a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to delete
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
return self._simple_delete_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_device(self, user_id, device_id, new_display_name=None):
|
||||||
|
"""Update a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to update
|
||||||
|
new_display_name (str|None): new displayname for device; None
|
||||||
|
to leave unchanged
|
||||||
|
Raises:
|
||||||
|
StoreError: if the device is not found
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
updates = {}
|
||||||
|
if new_display_name is not None:
|
||||||
|
updates["display_name"] = new_display_name
|
||||||
|
if not updates:
|
||||||
|
return defer.succeed(None)
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
updatevalues=updates,
|
||||||
|
desc="update_device",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_devices_by_user(self, user_id):
|
def get_devices_by_user(self, user_id):
|
||||||
"""Retrieve all of a user's registered devices.
|
"""Retrieve all of a user's registered devices.
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import twisted.internet.defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -123,3 +125,16 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@twisted.internet.defer.inlineCallbacks
|
||||||
|
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="e2e_device_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_e2e_device_keys_by_device"
|
||||||
|
)
|
||||||
|
yield self._simple_delete(
|
||||||
|
table="e2e_one_time_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
desc="delete_e2e_one_time_keys_by_device"
|
||||||
|
)
|
||||||
|
|
|
@ -117,21 +117,149 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_unread_push_actions_for_user_in_range(self, user_id,
|
def get_unread_push_actions_for_user_in_range_for_http(
|
||||||
min_stream_ordering,
|
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||||
max_stream_ordering=None,
|
):
|
||||||
limit=20):
|
"""Get a list of the most recent unread push actions for a given user,
|
||||||
|
within the given stream ordering range. Called by the httppusher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The user to fetch push actions for.
|
||||||
|
min_stream_ordering(int): The exclusive lower bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
max_stream_ordering(int): The inclusive upper bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
limit (int): The maximum number of rows to return.
|
||||||
|
Returns:
|
||||||
|
A promise which resolves to a list of dicts with the keys "event_id",
|
||||||
|
"room_id", "stream_ordering", "actions".
|
||||||
|
The list will be ordered by ascending stream_ordering.
|
||||||
|
The list will have between 0~limit entries.
|
||||||
|
"""
|
||||||
|
# find rooms that have a read receipt in them and return the next
|
||||||
|
# push actions
|
||||||
|
def get_after_receipt(txn):
|
||||||
|
# find rooms that have a read receipt in them and return the next
|
||||||
|
# push actions
|
||||||
|
sql = (
|
||||||
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
|
||||||
|
" FROM ("
|
||||||
|
" SELECT room_id,"
|
||||||
|
" MAX(topological_ordering) as topological_ordering,"
|
||||||
|
" MAX(stream_ordering) as stream_ordering"
|
||||||
|
" FROM events"
|
||||||
|
" INNER JOIN receipts_linearized USING (room_id, event_id)"
|
||||||
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
|
" GROUP BY room_id"
|
||||||
|
") AS rl,"
|
||||||
|
" event_push_actions AS ep"
|
||||||
|
" WHERE"
|
||||||
|
" ep.room_id = rl.room_id"
|
||||||
|
" AND ("
|
||||||
|
" ep.topological_ordering > rl.topological_ordering"
|
||||||
|
" OR ("
|
||||||
|
" ep.topological_ordering = rl.topological_ordering"
|
||||||
|
" AND ep.stream_ordering > rl.stream_ordering"
|
||||||
|
" )"
|
||||||
|
" )"
|
||||||
|
" AND ep.user_id = ?"
|
||||||
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||||
|
)
|
||||||
|
args = [
|
||||||
|
user_id, user_id,
|
||||||
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
|
]
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return txn.fetchall()
|
||||||
|
after_read_receipt = yield self.runInteraction(
|
||||||
|
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||||
|
)
|
||||||
|
|
||||||
|
# There are rooms with push actions in them but you don't have a read receipt in
|
||||||
|
# them e.g. rooms you've been invited to, so get push actions for rooms which do
|
||||||
|
# not have read receipts in them too.
|
||||||
|
def get_no_receipt(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||||
|
" e.received_ts"
|
||||||
|
" FROM event_push_actions AS ep"
|
||||||
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
|
" WHERE"
|
||||||
|
" ep.room_id NOT IN ("
|
||||||
|
" SELECT room_id FROM receipts_linearized"
|
||||||
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
|
" GROUP BY room_id"
|
||||||
|
" )"
|
||||||
|
" AND ep.user_id = ?"
|
||||||
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||||
|
)
|
||||||
|
args = [
|
||||||
|
user_id, user_id,
|
||||||
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
|
]
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return txn.fetchall()
|
||||||
|
no_read_receipt = yield self.runInteraction(
|
||||||
|
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||||
|
)
|
||||||
|
|
||||||
|
notifs = [
|
||||||
|
{
|
||||||
|
"event_id": row[0],
|
||||||
|
"room_id": row[1],
|
||||||
|
"stream_ordering": row[2],
|
||||||
|
"actions": json.loads(row[3]),
|
||||||
|
} for row in after_read_receipt + no_read_receipt
|
||||||
|
]
|
||||||
|
|
||||||
|
# Now sort it so it's ordered correctly, since currently it will
|
||||||
|
# contain results from the first query, correctly ordered, followed
|
||||||
|
# by results from the second query, but we want them all ordered
|
||||||
|
# by stream_ordering, oldest first.
|
||||||
|
notifs.sort(key=lambda r: r['stream_ordering'])
|
||||||
|
|
||||||
|
# Take only up to the limit. We have to stop at the limit because
|
||||||
|
# one of the subqueries may have hit the limit.
|
||||||
|
defer.returnValue(notifs[:limit])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_unread_push_actions_for_user_in_range_for_email(
|
||||||
|
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||||
|
):
|
||||||
|
"""Get a list of the most recent unread push actions for a given user,
|
||||||
|
within the given stream ordering range. Called by the emailpusher
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The user to fetch push actions for.
|
||||||
|
min_stream_ordering(int): The exclusive lower bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
max_stream_ordering(int): The inclusive upper bound on the
|
||||||
|
stream ordering of event push actions to fetch.
|
||||||
|
limit (int): The maximum number of rows to return.
|
||||||
|
Returns:
|
||||||
|
A promise which resolves to a list of dicts with the keys "event_id",
|
||||||
|
"room_id", "stream_ordering", "actions", "received_ts".
|
||||||
|
The list will be ordered by descending received_ts.
|
||||||
|
The list will have between 0~limit entries.
|
||||||
|
"""
|
||||||
|
# find rooms that have a read receipt in them and return the most recent
|
||||||
|
# push actions
|
||||||
def get_after_receipt(txn):
|
def get_after_receipt(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||||
" e.received_ts"
|
" e.received_ts"
|
||||||
" FROM ("
|
" FROM ("
|
||||||
" SELECT room_id, user_id, "
|
" SELECT room_id,"
|
||||||
" max(topological_ordering) as topological_ordering, "
|
" MAX(topological_ordering) as topological_ordering,"
|
||||||
" max(stream_ordering) as stream_ordering "
|
" MAX(stream_ordering) as stream_ordering"
|
||||||
" FROM events"
|
" FROM events"
|
||||||
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
|
" INNER JOIN receipts_linearized USING (room_id, event_id)"
|
||||||
" GROUP BY room_id, user_id"
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
|
" GROUP BY room_id"
|
||||||
") AS rl,"
|
") AS rl,"
|
||||||
" event_push_actions AS ep"
|
" event_push_actions AS ep"
|
||||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
|
@ -144,44 +272,49 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
" AND ep.stream_ordering > rl.stream_ordering"
|
" AND ep.stream_ordering > rl.stream_ordering"
|
||||||
" )"
|
" )"
|
||||||
" )"
|
" )"
|
||||||
" AND ep.stream_ordering > ?"
|
|
||||||
" AND ep.user_id = ?"
|
" AND ep.user_id = ?"
|
||||||
" AND ep.user_id = rl.user_id"
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||||
)
|
)
|
||||||
args = [min_stream_ordering, user_id]
|
args = [
|
||||||
if max_stream_ordering is not None:
|
user_id, user_id,
|
||||||
sql += " AND ep.stream_ordering <= ?"
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
args.append(max_stream_ordering)
|
]
|
||||||
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
|
|
||||||
args.append(limit)
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
after_read_receipt = yield self.runInteraction(
|
after_read_receipt = yield self.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range", get_after_receipt
|
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# There are rooms with push actions in them but you don't have a read receipt in
|
||||||
|
# them e.g. rooms you've been invited to, so get push actions for rooms which do
|
||||||
|
# not have read receipts in them too.
|
||||||
def get_no_receipt(txn):
|
def get_no_receipt(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||||
" e.received_ts"
|
" e.received_ts"
|
||||||
" FROM event_push_actions AS ep"
|
" FROM event_push_actions AS ep"
|
||||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||||
" WHERE ep.room_id not in ("
|
" WHERE"
|
||||||
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
|
" ep.room_id NOT IN ("
|
||||||
|
" SELECT room_id FROM receipts_linearized"
|
||||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||||
" GROUP BY room_id"
|
" GROUP BY room_id"
|
||||||
") AND ep.user_id = ? AND ep.stream_ordering > ?"
|
" )"
|
||||||
|
" AND ep.user_id = ?"
|
||||||
|
" AND ep.stream_ordering > ?"
|
||||||
|
" AND ep.stream_ordering <= ?"
|
||||||
|
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||||
)
|
)
|
||||||
args = [user_id, user_id, min_stream_ordering]
|
args = [
|
||||||
if max_stream_ordering is not None:
|
user_id, user_id,
|
||||||
sql += " AND ep.stream_ordering <= ?"
|
min_stream_ordering, max_stream_ordering, limit,
|
||||||
args.append(max_stream_ordering)
|
]
|
||||||
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
|
|
||||||
args.append(limit)
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
no_read_receipt = yield self.runInteraction(
|
no_read_receipt = yield self.runInteraction(
|
||||||
"get_unread_push_actions_for_user_in_range", get_no_receipt
|
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make a list of dicts from the two sets of results.
|
# Make a list of dicts from the two sets of results.
|
||||||
|
@ -198,7 +331,7 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
# Now sort it so it's ordered correctly, since currently it will
|
# Now sort it so it's ordered correctly, since currently it will
|
||||||
# contain results from the first query, correctly ordered, followed
|
# contain results from the first query, correctly ordered, followed
|
||||||
# by results from the second query, but we want them all ordered
|
# by results from the second query, but we want them all ordered
|
||||||
# by received_ts
|
# by received_ts (most recent first)
|
||||||
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
|
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
|
||||||
|
|
||||||
# Now return the first `limit`
|
# Now return the first `limit`
|
||||||
|
|
|
@ -397,6 +397,12 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
|
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
|
||||||
|
"""Insert some number of room events into the necessary database tables.
|
||||||
|
|
||||||
|
Rejected events are only inserted into the events table, the events_json table,
|
||||||
|
and the rejections table. Things reading from those table will need to check
|
||||||
|
whether the event was rejected.
|
||||||
|
"""
|
||||||
depth_updates = {}
|
depth_updates = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
# Remove the any existing cache entries for the event_ids
|
# Remove the any existing cache entries for the event_ids
|
||||||
|
@ -407,21 +413,11 @@ class EventsStore(SQLBaseStore):
|
||||||
event.room_id, event.internal_metadata.stream_ordering,
|
event.room_id, event.internal_metadata.stream_ordering,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not event.internal_metadata.is_outlier():
|
if not event.internal_metadata.is_outlier() and not context.rejected:
|
||||||
depth_updates[event.room_id] = max(
|
depth_updates[event.room_id] = max(
|
||||||
event.depth, depth_updates.get(event.room_id, event.depth)
|
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||||
)
|
)
|
||||||
|
|
||||||
if context.push_actions:
|
|
||||||
self._set_push_actions_for_event_and_users_txn(
|
|
||||||
txn, event, context.push_actions
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
|
||||||
self._remove_push_actions_for_event_id_txn(
|
|
||||||
txn, event.room_id, event.redacts
|
|
||||||
)
|
|
||||||
|
|
||||||
for room_id, depth in depth_updates.items():
|
for room_id, depth in depth_updates.items():
|
||||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||||
|
|
||||||
|
@ -431,14 +427,24 @@ class EventsStore(SQLBaseStore):
|
||||||
),
|
),
|
||||||
[event.event_id for event, _ in events_and_contexts]
|
[event.event_id for event, _ in events_and_contexts]
|
||||||
)
|
)
|
||||||
|
|
||||||
have_persisted = {
|
have_persisted = {
|
||||||
event_id: outlier
|
event_id: outlier
|
||||||
for event_id, outlier in txn.fetchall()
|
for event_id, outlier in txn.fetchall()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Remove the events that we've seen before.
|
||||||
event_map = {}
|
event_map = {}
|
||||||
to_remove = set()
|
to_remove = set()
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
|
if context.rejected:
|
||||||
|
# If the event is rejected then we don't care if the event
|
||||||
|
# was an outlier or not.
|
||||||
|
if event.event_id in have_persisted:
|
||||||
|
# If we have already seen the event then ignore it.
|
||||||
|
to_remove.add(event)
|
||||||
|
continue
|
||||||
|
|
||||||
# Handle the case of the list including the same event multiple
|
# Handle the case of the list including the same event multiple
|
||||||
# times. The tricky thing here is when they differ by whether
|
# times. The tricky thing here is when they differ by whether
|
||||||
# they are an outlier.
|
# they are an outlier.
|
||||||
|
@ -463,6 +469,12 @@ class EventsStore(SQLBaseStore):
|
||||||
|
|
||||||
outlier_persisted = have_persisted[event.event_id]
|
outlier_persisted = have_persisted[event.event_id]
|
||||||
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||||
|
# We received a copy of an event that we had already stored as
|
||||||
|
# an outlier in the database. We now have some state at that
|
||||||
|
# so we need to update the state_groups table with that state.
|
||||||
|
|
||||||
|
# insert into the state_group, state_groups_state and
|
||||||
|
# event_to_state_groups tables.
|
||||||
self._store_mult_state_groups_txn(txn, ((event, context),))
|
self._store_mult_state_groups_txn(txn, ((event, context),))
|
||||||
|
|
||||||
metadata_json = encode_json(
|
metadata_json = encode_json(
|
||||||
|
@ -478,6 +490,8 @@ class EventsStore(SQLBaseStore):
|
||||||
(metadata_json, event.event_id,)
|
(metadata_json, event.event_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add an entry to the ex_outlier_stream table to replicate the
|
||||||
|
# change in outlier status to our workers.
|
||||||
stream_order = event.internal_metadata.stream_ordering
|
stream_order = event.internal_metadata.stream_ordering
|
||||||
state_group_id = context.state_group or context.new_state_group_id
|
state_group_id = context.state_group or context.new_state_group_id
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
|
@ -499,6 +513,8 @@ class EventsStore(SQLBaseStore):
|
||||||
(False, event.event_id,)
|
(False, event.event_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update the event_backward_extremities table now that this
|
||||||
|
# event isn't an outlier any more.
|
||||||
self._update_extremeties(txn, [event])
|
self._update_extremeties(txn, [event])
|
||||||
|
|
||||||
events_and_contexts = [
|
events_and_contexts = [
|
||||||
|
@ -506,38 +522,12 @@ class EventsStore(SQLBaseStore):
|
||||||
]
|
]
|
||||||
|
|
||||||
if not events_and_contexts:
|
if not events_and_contexts:
|
||||||
|
# Make sure we don't pass an empty list to functions that expect to
|
||||||
|
# be storing at least one element.
|
||||||
return
|
return
|
||||||
|
|
||||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
# From this point onwards the events are only events that we haven't
|
||||||
|
# seen before.
|
||||||
self._handle_mult_prev_events(
|
|
||||||
txn,
|
|
||||||
events=[event for event, _ in events_and_contexts],
|
|
||||||
)
|
|
||||||
|
|
||||||
for event, _ in events_and_contexts:
|
|
||||||
if event.type == EventTypes.Name:
|
|
||||||
self._store_room_name_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Topic:
|
|
||||||
self._store_room_topic_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Message:
|
|
||||||
self._store_room_message_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.Redaction:
|
|
||||||
self._store_redaction(txn, event)
|
|
||||||
elif event.type == EventTypes.RoomHistoryVisibility:
|
|
||||||
self._store_history_visibility_txn(txn, event)
|
|
||||||
elif event.type == EventTypes.GuestAccess:
|
|
||||||
self._store_guest_access_txn(txn, event)
|
|
||||||
|
|
||||||
self._store_room_members_txn(
|
|
||||||
txn,
|
|
||||||
[
|
|
||||||
event
|
|
||||||
for event, _ in events_and_contexts
|
|
||||||
if event.type == EventTypes.Member
|
|
||||||
],
|
|
||||||
backfilled=backfilled,
|
|
||||||
)
|
|
||||||
|
|
||||||
def event_dict(event):
|
def event_dict(event):
|
||||||
return {
|
return {
|
||||||
|
@ -591,10 +581,41 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove the rejected events from the list now that we've added them
|
||||||
|
# to the events table and the events_json table.
|
||||||
|
to_remove = set()
|
||||||
|
for event, context in events_and_contexts:
|
||||||
if context.rejected:
|
if context.rejected:
|
||||||
|
# Insert the event_id into the rejections table
|
||||||
self._store_rejections_txn(
|
self._store_rejections_txn(
|
||||||
txn, event.event_id, context.rejected
|
txn, event.event_id, context.rejected
|
||||||
)
|
)
|
||||||
|
to_remove.add(event)
|
||||||
|
|
||||||
|
events_and_contexts = [
|
||||||
|
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||||
|
]
|
||||||
|
|
||||||
|
if not events_and_contexts:
|
||||||
|
# Make sure we don't pass an empty list to functions that expect to
|
||||||
|
# be storing at least one element.
|
||||||
|
return
|
||||||
|
|
||||||
|
# From this point onwards the events are only ones that weren't rejected.
|
||||||
|
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
# Insert all the push actions into the event_push_actions table.
|
||||||
|
if context.push_actions:
|
||||||
|
self._set_push_actions_for_event_and_users_txn(
|
||||||
|
txn, event, context.push_actions
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||||
|
# Remove the entries in the event_push_actions table for the
|
||||||
|
# redacted event.
|
||||||
|
self._remove_push_actions_for_event_id_txn(
|
||||||
|
txn, event.room_id, event.redacts
|
||||||
|
)
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -610,6 +631,49 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Insert into the state_groups, state_groups_state, and
|
||||||
|
# event_to_state_groups tables.
|
||||||
|
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||||
|
|
||||||
|
# Update the event_forward_extremities, event_backward_extremities and
|
||||||
|
# event_edges tables.
|
||||||
|
self._handle_mult_prev_events(
|
||||||
|
txn,
|
||||||
|
events=[event for event, _ in events_and_contexts],
|
||||||
|
)
|
||||||
|
|
||||||
|
for event, _ in events_and_contexts:
|
||||||
|
if event.type == EventTypes.Name:
|
||||||
|
# Insert into the room_names and event_search tables.
|
||||||
|
self._store_room_name_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Topic:
|
||||||
|
# Insert into the topics table and event_search table.
|
||||||
|
self._store_room_topic_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Message:
|
||||||
|
# Insert into the event_search table.
|
||||||
|
self._store_room_message_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.Redaction:
|
||||||
|
# Insert into the redactions table.
|
||||||
|
self._store_redaction(txn, event)
|
||||||
|
elif event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
# Insert into the event_search table.
|
||||||
|
self._store_history_visibility_txn(txn, event)
|
||||||
|
elif event.type == EventTypes.GuestAccess:
|
||||||
|
# Insert into the event_search table.
|
||||||
|
self._store_guest_access_txn(txn, event)
|
||||||
|
|
||||||
|
# Insert into the room_memberships table.
|
||||||
|
self._store_room_members_txn(
|
||||||
|
txn,
|
||||||
|
[
|
||||||
|
event
|
||||||
|
for event, _ in events_and_contexts
|
||||||
|
if event.type == EventTypes.Member
|
||||||
|
],
|
||||||
|
backfilled=backfilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert event_reference_hashes table.
|
||||||
self._store_event_reference_hashes_txn(
|
self._store_event_reference_hashes_txn(
|
||||||
txn, [event for event, _ in events_and_contexts]
|
txn, [event for event, _ in events_and_contexts]
|
||||||
)
|
)
|
||||||
|
@ -654,6 +718,7 @@ class EventsStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prefill the event cache
|
||||||
self._add_to_cache(txn, events_and_contexts)
|
self._add_to_cache(txn, events_and_contexts)
|
||||||
|
|
||||||
if backfilled:
|
if backfilled:
|
||||||
|
@ -666,11 +731,6 @@ class EventsStore(SQLBaseStore):
|
||||||
# Outlier events shouldn't clobber the current state.
|
# Outlier events shouldn't clobber the current state.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.rejected:
|
|
||||||
# If the event failed it's auth checks then it shouldn't
|
|
||||||
# clobbler the current state.
|
|
||||||
continue
|
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._get_current_state_for_key.invalidate,
|
self._get_current_state_for_key.invalidate,
|
||||||
(event.room_id, event.type, event.state_key,)
|
(event.room_id, event.type, event.state_key,)
|
||||||
|
|
|
@ -18,18 +18,31 @@ import re
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
from synapse.storage import background_updates
|
||||||
from ._base import SQLBaseStore
|
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RegistrationStore, self).__init__(hs)
|
super(RegistrationStore, self).__init__(hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"access_tokens_device_index",
|
||||||
|
index_name="access_tokens_device_id",
|
||||||
|
table="access_tokens",
|
||||||
|
columns=["user_id", "device_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
"refresh_tokens_device_index",
|
||||||
|
index_name="refresh_tokens_device_id",
|
||||||
|
table="refresh_tokens",
|
||||||
|
columns=["user_id", "device_id"],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_access_token_to_user(self, user_id, token, device_id=None):
|
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||||
"""Adds an access token for the given user.
|
"""Adds an access token for the given user.
|
||||||
|
@ -238,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
self.get_user_by_id.invalidate((user_id,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
|
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||||
def f(txn):
|
device_id=None,
|
||||||
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
|
delete_refresh_tokens=False):
|
||||||
|
"""
|
||||||
|
Invalidate access/refresh tokens belonging to a user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of user the tokens belong to
|
||||||
|
except_token_ids (list[str]): list of access_tokens which should
|
||||||
|
*not* be deleted
|
||||||
|
device_id (str|None): ID of device the tokens are associated with.
|
||||||
|
If None, tokens associated with any device (or no device) will
|
||||||
|
be deleted
|
||||||
|
delete_refresh_tokens (bool): True to delete refresh tokens as
|
||||||
|
well as access tokens.
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
def f(txn, table, except_tokens, call_after_delete):
|
||||||
|
sql = "SELECT token FROM %s WHERE user_id = ?" % table
|
||||||
clauses = [user_id]
|
clauses = [user_id]
|
||||||
|
|
||||||
if except_token_ids:
|
if device_id is not None:
|
||||||
|
sql += " AND device_id = ?"
|
||||||
|
clauses.append(device_id)
|
||||||
|
|
||||||
|
if except_tokens:
|
||||||
sql += " AND id NOT IN (%s)" % (
|
sql += " AND id NOT IN (%s)" % (
|
||||||
",".join(["?" for _ in except_token_ids]),
|
",".join(["?" for _ in except_tokens]),
|
||||||
)
|
)
|
||||||
clauses += except_token_ids
|
clauses += except_tokens
|
||||||
|
|
||||||
txn.execute(sql, clauses)
|
txn.execute(sql, clauses)
|
||||||
|
|
||||||
|
@ -256,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
|
||||||
n = 100
|
n = 100
|
||||||
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
|
if call_after_delete:
|
||||||
for row in chunk:
|
for row in chunk:
|
||||||
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
|
txn.call_after(call_after_delete, (row[0],))
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"DELETE FROM access_tokens WHERE token in (%s)" % (
|
"DELETE FROM %s WHERE token in (%s)" % (
|
||||||
|
table,
|
||||||
",".join(["?" for _ in chunk]),
|
",".join(["?" for _ in chunk]),
|
||||||
), [r[0] for r in chunk]
|
), [r[0] for r in chunk]
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.runInteraction("user_delete_access_tokens", f)
|
# delete refresh tokens first, to stop new access tokens being
|
||||||
|
# allocated while our backs are turned
|
||||||
|
if delete_refresh_tokens:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"user_delete_access_tokens", f,
|
||||||
|
table="refresh_tokens",
|
||||||
|
except_tokens=[],
|
||||||
|
call_after_delete=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.runInteraction(
|
||||||
|
"user_delete_access_tokens", f,
|
||||||
|
table="access_tokens",
|
||||||
|
except_tokens=except_token_ids,
|
||||||
|
call_after_delete=self.get_user_by_access_token.invalidate,
|
||||||
|
)
|
||||||
|
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -288,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
|
||||||
Args:
|
Args:
|
||||||
token (str): The access token of a user.
|
token (str): The access token of a user.
|
||||||
Returns:
|
Returns:
|
||||||
dict: Including the name (user_id) and the ID of their access token.
|
defer.Deferred: None, if the token did not match, otherwise dict
|
||||||
Raises:
|
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||||
StoreError if no user was found.
|
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_user_by_access_token",
|
"get_user_by_access_token",
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('access_tokens_device_index', '{}');
|
|
@ -0,0 +1,19 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- make sure that we have a device record for each set of E2E keys, so that the
|
||||||
|
-- user can delete them if they like.
|
||||||
|
INSERT INTO devices
|
||||||
|
SELECT user_id, device_id, 'unknown device' FROM e2e_device_keys_json;
|
|
@ -0,0 +1,17 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('refresh_tokens_device_index', '{}');
|
|
@ -13,4 +13,5 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
CREATE INDEX user_ips_device_id ON user_ips(user_id, device_id, last_seen);
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('user_ips_device_index', '{}');
|
||||||
|
|
|
@ -24,6 +24,7 @@ from collections import namedtuple
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if result and result["response_code"]:
|
if result and result["response_code"]:
|
||||||
return result["response_code"], result["response_json"]
|
return result["response_code"], json.loads(str(result["response_json"]))
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
|
Requester = namedtuple("Requester",
|
||||||
|
["user", "access_token_id", "is_guest", "device_id"])
|
||||||
|
"""
|
||||||
|
Represents the user making a request
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
user (UserID): id of the user making the request
|
||||||
|
access_token_id (int|None): *ID* of the access token used for this
|
||||||
|
request, or None if it came via the appservice API or similar
|
||||||
|
is_guest (bool): True if the user making this request is a guest user
|
||||||
|
device_id (str|None): device_id which was set at authentication time
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def create_requester(user_id, access_token_id=None, is_guest=False,
|
||||||
|
device_id=None):
|
||||||
|
"""
|
||||||
|
Create a new ``Requester`` object
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str|UserID): id of the user making the request
|
||||||
|
access_token_id (int|None): *ID* of the access token used for this
|
||||||
|
request, or None if it came via the appservice API or similar
|
||||||
|
is_guest (bool): True if the user making this request is a guest user
|
||||||
|
device_id (str|None): device_id which was set at authentication time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Requester
|
||||||
|
"""
|
||||||
|
if not isinstance(user_id, UserID):
|
||||||
|
user_id = UserID.from_string(user_id)
|
||||||
|
return Requester(user_id, access_token_id, is_guest, device_id)
|
||||||
|
|
||||||
|
|
||||||
def get_domain_from_id(string):
|
def get_domain_from_id(string):
|
||||||
|
|
|
@ -84,7 +84,7 @@ class Measure(object):
|
||||||
|
|
||||||
if context != self.start_context:
|
if context != self.start_context:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
|
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
|
||||||
context, self.start_context, self.name
|
context, self.start_context, self.name
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
|
@ -83,7 +83,10 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||||
):
|
):
|
||||||
if ("m.room.member", my_member_event.sender) in room_state:
|
if ("m.room.member", my_member_event.sender) in room_state:
|
||||||
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
|
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
|
||||||
|
if fallback_to_single_member:
|
||||||
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
|
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return "Room Invite"
|
return "Room Invite"
|
||||||
|
|
||||||
|
|
|
@ -128,7 +128,7 @@ class RetryDestinationLimiter(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_err_code = False
|
valid_err_code = False
|
||||||
if exc_type is CodeMessageException:
|
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
||||||
valid_err_code = 0 <= exc_val.code < 500
|
valid_err_code = 0 <= exc_val.code < 500
|
||||||
|
|
||||||
if exc_type is None or valid_err_code:
|
if exc_type is None or valid_err_code:
|
||||||
|
|
|
@ -12,11 +12,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse import types
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
import synapse.handlers.device
|
import synapse.handlers.device
|
||||||
|
|
||||||
import synapse.storage
|
import synapse.storage
|
||||||
|
from synapse import types
|
||||||
from tests import unittest, utils
|
from tests import unittest, utils
|
||||||
|
|
||||||
user1 = "@boris:aaa"
|
user1 = "@boris:aaa"
|
||||||
|
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
||||||
self.store = None # type: synapse.storage.DataStore
|
self.store = None # type: synapse.storage.DataStore
|
||||||
self.handler = None # type: device.DeviceHandler
|
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
||||||
self.clock = None # type: utils.MockClock
|
self.clock = None # type: utils.MockClock
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -84,28 +87,31 @@ class DeviceTestCase(unittest.TestCase):
|
||||||
yield self._record_users()
|
yield self._record_users()
|
||||||
|
|
||||||
res = yield self.handler.get_devices_by_user(user1)
|
res = yield self.handler.get_devices_by_user(user1)
|
||||||
self.assertEqual(3, len(res.keys()))
|
self.assertEqual(3, len(res))
|
||||||
|
device_map = {
|
||||||
|
d["device_id"]: d for d in res
|
||||||
|
}
|
||||||
self.assertDictContainsSubset({
|
self.assertDictContainsSubset({
|
||||||
"user_id": user1,
|
"user_id": user1,
|
||||||
"device_id": "xyz",
|
"device_id": "xyz",
|
||||||
"display_name": "display 0",
|
"display_name": "display 0",
|
||||||
"last_seen_ip": None,
|
"last_seen_ip": None,
|
||||||
"last_seen_ts": None,
|
"last_seen_ts": None,
|
||||||
}, res["xyz"])
|
}, device_map["xyz"])
|
||||||
self.assertDictContainsSubset({
|
self.assertDictContainsSubset({
|
||||||
"user_id": user1,
|
"user_id": user1,
|
||||||
"device_id": "fco",
|
"device_id": "fco",
|
||||||
"display_name": "display 1",
|
"display_name": "display 1",
|
||||||
"last_seen_ip": "ip1",
|
"last_seen_ip": "ip1",
|
||||||
"last_seen_ts": 1000000,
|
"last_seen_ts": 1000000,
|
||||||
}, res["fco"])
|
}, device_map["fco"])
|
||||||
self.assertDictContainsSubset({
|
self.assertDictContainsSubset({
|
||||||
"user_id": user1,
|
"user_id": user1,
|
||||||
"device_id": "abc",
|
"device_id": "abc",
|
||||||
"display_name": "display 2",
|
"display_name": "display 2",
|
||||||
"last_seen_ip": "ip3",
|
"last_seen_ip": "ip3",
|
||||||
"last_seen_ts": 3000000,
|
"last_seen_ts": 3000000,
|
||||||
}, res["abc"])
|
}, device_map["abc"])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_device(self):
|
def test_get_device(self):
|
||||||
|
@ -120,6 +126,37 @@ class DeviceTestCase(unittest.TestCase):
|
||||||
"last_seen_ts": 3000000,
|
"last_seen_ts": 3000000,
|
||||||
}, res)
|
}, res)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_delete_device(self):
|
||||||
|
yield self._record_users()
|
||||||
|
|
||||||
|
# delete the device
|
||||||
|
yield self.handler.delete_device(user1, "abc")
|
||||||
|
|
||||||
|
# check the device was deleted
|
||||||
|
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||||
|
yield self.handler.get_device(user1, "abc")
|
||||||
|
|
||||||
|
# we'd like to check the access token was invalidated, but that's a
|
||||||
|
# bit of a PITA.
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_device(self):
|
||||||
|
yield self._record_users()
|
||||||
|
|
||||||
|
update = {"display_name": "new display"}
|
||||||
|
yield self.handler.update_device(user1, "abc", update)
|
||||||
|
|
||||||
|
res = yield self.handler.get_device(user1, "abc")
|
||||||
|
self.assertEqual(res["display_name"], "new display")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_unknown_device(self):
|
||||||
|
update = {"display_name": "new_display"}
|
||||||
|
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||||
|
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||||
|
update)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _record_users(self):
|
def _record_users(self):
|
||||||
# check this works for both devices which have a recorded client_ip,
|
# check this works for both devices which have a recorded client_ip,
|
||||||
|
|
|
@ -19,11 +19,12 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock, NonCallableMock
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.handlers.profile import ProfileHandler
|
from synapse.handlers.profile import ProfileHandler
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver, requester_for_user
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class ProfileHandlers(object):
|
class ProfileHandlers(object):
|
||||||
|
@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
def test_set_my_name(self):
|
def test_set_my_name(self):
|
||||||
yield self.handler.set_displayname(
|
yield self.handler.set_displayname(
|
||||||
self.frank,
|
self.frank,
|
||||||
requester_for_user(self.frank),
|
synapse.types.create_requester(self.frank),
|
||||||
"Frank Jr."
|
"Frank Jr."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
def test_set_my_name_noauth(self):
|
def test_set_my_name_noauth(self):
|
||||||
d = self.handler.set_displayname(
|
d = self.handler.set_displayname(
|
||||||
self.frank,
|
self.frank,
|
||||||
requester_for_user(self.bob),
|
synapse.types.create_requester(self.bob),
|
||||||
"Frank Jr."
|
"Frank Jr."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_avatar(self):
|
def test_set_my_avatar(self):
|
||||||
yield self.handler.set_avatar_url(
|
yield self.handler.set_avatar_url(
|
||||||
self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
|
self.frank, synapse.types.create_requester(self.frank),
|
||||||
|
"http://my.server/pic.gif"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
|
|
|
@ -13,15 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.replication.resource import ReplicationResource
|
|
||||||
from synapse.types import Requester, UserID
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from tests import unittest
|
|
||||||
from tests.utils import setup_test_homeserver, requester_for_user
|
|
||||||
from mock import Mock, NonCallableMock
|
|
||||||
import json
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
from mock import Mock, NonCallableMock
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.types
|
||||||
|
from synapse.replication.resource import ReplicationResource
|
||||||
|
from synapse.types import UserID
|
||||||
|
from tests import unittest
|
||||||
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
class ReplicationResourceCase(unittest.TestCase):
|
class ReplicationResourceCase(unittest.TestCase):
|
||||||
|
@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
def test_events_and_state(self):
|
def test_events_and_state(self):
|
||||||
get = self.get(events="-1", state="-1", timeout="0")
|
get = self.get(events="-1", state="-1", timeout="0")
|
||||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
Requester(self.user, "", False), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
code, body = yield get
|
code, body = yield get
|
||||||
self.assertEquals(code, 200)
|
self.assertEquals(code, 200)
|
||||||
|
@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
def send_text_message(self, room_id, message):
|
def send_text_message(self, room_id, message):
|
||||||
handler = self.hs.get_handlers().message_handler
|
handler = self.hs.get_handlers().message_handler
|
||||||
event = yield handler.create_and_send_nonmember_event(
|
event = yield handler.create_and_send_nonmember_event(
|
||||||
requester_for_user(self.user),
|
synapse.types.create_requester(self.user),
|
||||||
{
|
{
|
||||||
"type": "m.room.message",
|
"type": "m.room.message",
|
||||||
"content": {"body": "message", "msgtype": "m.text"},
|
"content": {"body": "message", "msgtype": "m.text"},
|
||||||
|
@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_room(self):
|
def create_room(self):
|
||||||
result = yield self.hs.get_handlers().room_creation_handler.create_room(
|
result = yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
Requester(self.user, "", False), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
defer.returnValue(result["room_id"])
|
defer.returnValue(result["room_id"])
|
||||||
|
|
||||||
|
|
|
@ -14,17 +14,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests REST events for /profile paths."""
|
"""Tests REST events for /profile paths."""
|
||||||
from tests import unittest
|
from mock import Mock
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock
|
import synapse.types
|
||||||
|
|
||||||
from ....utils import MockHttpResource, setup_test_homeserver
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.types import Requester, UserID
|
|
||||||
|
|
||||||
from synapse.rest.client.v1 import profile
|
from synapse.rest.client.v1 import profile
|
||||||
|
from tests import unittest
|
||||||
|
from ....utils import MockHttpResource, setup_test_homeserver
|
||||||
|
|
||||||
myid = "@1234ABCD:test"
|
myid = "@1234ABCD:test"
|
||||||
PATH_PREFIX = "/_matrix/client/api/v1"
|
PATH_PREFIX = "/_matrix/client/api/v1"
|
||||||
|
@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
return Requester(UserID.from_string(myid), "", False)
|
return synapse.types.create_requester(myid)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
|
|
@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.registration_handler.appservice_register = Mock(
|
self.registration_handler.appservice_register = Mock(
|
||||||
return_value=user_id
|
return_value=user_id
|
||||||
)
|
)
|
||||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||||
|
return_value=(token, "kermits_refresh_token")
|
||||||
|
)
|
||||||
|
|
||||||
(code, result) = yield self.servlet.on_POST(self.request)
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
self.assertEquals(code, 200)
|
self.assertEquals(code, 200)
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
|
"refresh_token": "kermits_refresh_token",
|
||||||
"home_server": self.hs.hostname
|
"home_server": self.hs.hostname
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
|
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||||
|
return_value=(token, "kermits_refresh_token")
|
||||||
|
)
|
||||||
self.device_handler.check_device_registered = \
|
self.device_handler.check_device_registered = \
|
||||||
Mock(return_value=device_id)
|
Mock(return_value=device_id)
|
||||||
|
|
||||||
|
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
|
"refresh_token": "kermits_refresh_token",
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
self.assertIn("refresh_token", result)
|
self.assertIn("refresh_token", result)
|
||||||
self.auth_handler.issue_access_token.assert_called_once_with(
|
self.auth_handler.get_login_tuple_for_user_id(
|
||||||
user_id, device_id=device_id)
|
user_id, device_id=device_id, initial_device_display_name=None)
|
||||||
|
|
||||||
def test_POST_disabled_registration(self):
|
def test_POST_disabled_registration(self):
|
||||||
self.hs.config.enable_registration = False
|
self.hs.config.enable_registration = False
|
||||||
|
|
|
@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver()
|
hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
||||||
"test_update", self.update_handler
|
"test_update", self.update_handler
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# run the real background updates, to get them out the way
|
||||||
|
# (perhaps we should run them as part of the test HS setup, since we
|
||||||
|
# run all of the other schema setup stuff there?)
|
||||||
|
while True:
|
||||||
|
res = yield self.store.do_next_background_update(1000)
|
||||||
|
if res is None:
|
||||||
|
break
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_do_background_update(self):
|
def test_do_background_update(self):
|
||||||
desired_count = 1000
|
desired_count = 1000
|
||||||
duration_ms = 42
|
duration_ms = 42
|
||||||
|
|
||||||
|
# first step: make a bit of progress
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update(progress, count):
|
def update(progress, count):
|
||||||
self.clock.advance_time_msec(count * duration_ms)
|
self.clock.advance_time_msec(count * duration_ms)
|
||||||
|
@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
||||||
yield self.store.start_background_update("test_update", {"my_key": 1})
|
yield self.store.start_background_update("test_update", {"my_key": 1})
|
||||||
|
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
result = yield self.store.do_background_update(
|
result = yield self.store.do_next_background_update(
|
||||||
duration_ms * desired_count
|
duration_ms * desired_count
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
|
@ -50,15 +59,15 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
||||||
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
|
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# second step: complete the update
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update(progress, count):
|
def update(progress, count):
|
||||||
yield self.store._end_background_update("test_update")
|
yield self.store._end_background_update("test_update")
|
||||||
defer.returnValue(count)
|
defer.returnValue(count)
|
||||||
|
|
||||||
self.update_handler.side_effect = update
|
self.update_handler.side_effect = update
|
||||||
|
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
result = yield self.store.do_background_update(
|
result = yield self.store.do_next_background_update(
|
||||||
duration_ms * desired_count
|
duration_ms * desired_count
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
|
@ -66,8 +75,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
||||||
{"my_key": 2}, desired_count
|
{"my_key": 2}, desired_count
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# third step: we don't expect to be called any more
|
||||||
self.update_handler.reset_mock()
|
self.update_handler.reset_mock()
|
||||||
result = yield self.store.do_background_update(
|
result = yield self.store.do_next_background_update(
|
||||||
duration_ms * desired_count
|
duration_ms * desired_count
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
import tests.utils
|
import tests.utils
|
||||||
|
|
||||||
|
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
"device_id": "device2",
|
"device_id": "device2",
|
||||||
"display_name": "display_name 2",
|
"display_name": "display_name 2",
|
||||||
}, res["device2"])
|
}, res["device2"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_device(self):
|
||||||
|
yield self.store.store_device(
|
||||||
|
"user_id", "device_id", "display_name 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
|
# do a no-op first
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "device_id",
|
||||||
|
)
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
|
# do the update
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "device_id",
|
||||||
|
new_display_name="display_name 2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# check it worked
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 2", res["display_name"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_unknown_device(self):
|
||||||
|
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "unknown_device_id",
|
||||||
|
new_display_name="display_name 2",
|
||||||
|
)
|
||||||
|
self.assertEqual(404, cm.exception.code)
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import tests.unittest
|
||||||
|
import tests.utils
|
||||||
|
|
||||||
|
USER_ID = "@user:example.com"
|
||||||
|
|
||||||
|
|
||||||
|
class EventPushActionsStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
hs = yield tests.utils.setup_test_homeserver()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_unread_push_actions_for_user_in_range_for_http(self):
|
||||||
|
yield self.store.get_unread_push_actions_for_user_in_range_for_http(
|
||||||
|
USER_ID, 0, 1000, 20
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_unread_push_actions_for_user_in_range_for_email(self):
|
||||||
|
yield self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||||
|
USER_ID, 0, 1000, 20
|
||||||
|
)
|
|
@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
with self.assertRaises(StoreError):
|
with self.assertRaises(StoreError):
|
||||||
yield self.store.exchange_refresh_token(last_token, generator.generate)
|
yield self.store.exchange_refresh_token(last_token, generator.generate)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_user_delete_access_tokens(self):
|
||||||
|
# add some tokens
|
||||||
|
generator = TokenGenerator()
|
||||||
|
refresh_token = generator.generate(self.user_id)
|
||||||
|
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
|
||||||
|
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
|
||||||
|
self.device_id)
|
||||||
|
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
|
||||||
|
self.device_id)
|
||||||
|
|
||||||
|
# now delete some
|
||||||
|
yield self.store.user_delete_access_tokens(
|
||||||
|
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
|
||||||
|
|
||||||
|
# check they were deleted
|
||||||
|
user = yield self.store.get_user_by_access_token(self.tokens[1])
|
||||||
|
self.assertIsNone(user, "access token was not deleted by device_id")
|
||||||
|
with self.assertRaises(StoreError):
|
||||||
|
yield self.store.exchange_refresh_token(refresh_token,
|
||||||
|
generator.generate)
|
||||||
|
|
||||||
|
# check the one not associated with the device was not deleted
|
||||||
|
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||||
|
self.assertEqual(self.user_id, user["name"])
|
||||||
|
|
||||||
|
# now delete the rest
|
||||||
|
yield self.store.user_delete_access_tokens(
|
||||||
|
self.user_id, delete_refresh_tokens=True)
|
||||||
|
|
||||||
|
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||||
|
self.assertIsNone(user,
|
||||||
|
"access token was not deleted without device_id")
|
||||||
|
|
||||||
|
|
||||||
class TokenGenerator:
|
class TokenGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -17,13 +17,18 @@ from twisted.trial import unittest
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
# logging doesn't have a "don't log anything at all EVARRRR setting,
|
# logging doesn't have a "don't log anything at all EVARRRR setting,
|
||||||
# but since the highest value is 50, 1000000 should do ;)
|
# but since the highest value is 50, 1000000 should do ;)
|
||||||
NEVER = 1000000
|
NEVER = 1000000
|
||||||
|
|
||||||
logging.getLogger().addHandler(logging.StreamHandler())
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(logging.Formatter(
|
||||||
|
"%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
|
||||||
|
))
|
||||||
|
logging.getLogger().addHandler(handler)
|
||||||
logging.getLogger().setLevel(NEVER)
|
logging.getLogger().setLevel(NEVER)
|
||||||
|
logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
|
||||||
|
logging.getLogger("synapse.storage.txn").setLevel(NEVER)
|
||||||
|
|
||||||
|
|
||||||
def around(target):
|
def around(target):
|
||||||
|
@ -70,8 +75,6 @@ class TestCase(unittest.TestCase):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
logging.getLogger().setLevel(level)
|
logging.getLogger().setLevel(level)
|
||||||
# Don't set SQL logging
|
|
||||||
logging.getLogger("synapse.storage").setLevel(old_level)
|
|
||||||
return orig()
|
return orig()
|
||||||
|
|
||||||
def assertObjectHasAttributes(self, attrs, obj):
|
def assertObjectHasAttributes(self, attrs, obj):
|
||||||
|
|
|
@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.federation.transport import server
|
from synapse.federation.transport import server
|
||||||
from synapse.types import Requester
|
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
@ -512,7 +511,3 @@ class DeferredMockCallable(object):
|
||||||
"call(%s)" % _format_call(c[0], c[1]) for c in calls
|
"call(%s)" % _format_call(c[0], c[1]) for c in calls
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def requester_for_user(user):
|
|
||||||
return Requester(user, None, False)
|
|
||||||
|
|
Loading…
Reference in New Issue