Merge branch 'erikj/shared_secret' into erikj/test2
This commit is contained in:
commit
a17e7caeb7
|
@ -25,18 +25,26 @@ import urllib2
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def request_registration(user, password, server_location, shared_secret):
|
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||||
mac = hmac.new(
|
mac = hmac.new(
|
||||||
key=shared_secret,
|
key=shared_secret,
|
||||||
msg=user,
|
|
||||||
digestmod=hashlib.sha1,
|
digestmod=hashlib.sha1,
|
||||||
).hexdigest()
|
)
|
||||||
|
|
||||||
|
mac.update(user)
|
||||||
|
mac.update("\x00")
|
||||||
|
mac.update(password)
|
||||||
|
mac.update("\x00")
|
||||||
|
mac.update("admin" if admin else "notadmin")
|
||||||
|
|
||||||
|
mac = mac.hexdigest()
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"user": user,
|
"user": user,
|
||||||
"password": password,
|
"password": password,
|
||||||
"mac": mac,
|
"mac": mac,
|
||||||
"type": "org.matrix.login.shared_secret",
|
"type": "org.matrix.login.shared_secret",
|
||||||
|
"admin": admin,
|
||||||
}
|
}
|
||||||
|
|
||||||
server_location = server_location.rstrip("/")
|
server_location = server_location.rstrip("/")
|
||||||
|
@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def register_new_user(user, password, server_location, shared_secret):
|
def register_new_user(user, password, server_location, shared_secret, admin):
|
||||||
if not user:
|
if not user:
|
||||||
try:
|
try:
|
||||||
default_user = getpass.getuser()
|
default_user = getpass.getuser()
|
||||||
|
@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
|
||||||
print "Passwords do not match"
|
print "Passwords do not match"
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
request_registration(user, password, server_location, shared_secret)
|
if not admin:
|
||||||
|
admin = raw_input("Make admin [no]: ")
|
||||||
|
if admin in ("y", "yes", "true"):
|
||||||
|
admin = True
|
||||||
|
else:
|
||||||
|
admin = False
|
||||||
|
|
||||||
|
request_registration(user, password, server_location, shared_secret, bool(admin))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -119,6 +134,11 @@ if __name__ == "__main__":
|
||||||
default=None,
|
default=None,
|
||||||
help="New password for user. Will prompt if omitted.",
|
help="New password for user. Will prompt if omitted.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-a", "--admin",
|
||||||
|
action="store_true",
|
||||||
|
help="Register new user as an admin. Will prompt if omitted.",
|
||||||
|
)
|
||||||
|
|
||||||
group = parser.add_mutually_exclusive_group(required=True)
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
|
@ -151,4 +171,4 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
secret = args.shared_secret
|
secret = args.shared_secret
|
||||||
|
|
||||||
register_new_user(args.user, args.password, args.server_url, secret)
|
register_new_user(args.user, args.password, args.server_url, secret, args.admin)
|
||||||
|
|
|
@ -42,8 +42,9 @@ class Codes(object):
|
||||||
TOO_LARGE = "M_TOO_LARGE"
|
TOO_LARGE = "M_TOO_LARGE"
|
||||||
EXCLUSIVE = "M_EXCLUSIVE"
|
EXCLUSIVE = "M_EXCLUSIVE"
|
||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
THREEPID_IN_USE = "THREEPID_IN_USE"
|
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||||
|
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
|
|
|
@ -23,10 +23,14 @@ class PasswordConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
password_config = config.get("password_config", {})
|
password_config = config.get("password_config", {})
|
||||||
self.password_enabled = password_config.get("enabled", True)
|
self.password_enabled = password_config.get("enabled", True)
|
||||||
|
self.password_pepper = password_config.get("pepper", "")
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
return """
|
return """
|
||||||
# Enable password for login.
|
# Enable password for login.
|
||||||
password_config:
|
password_config:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
# Change to a secret random string.
|
||||||
|
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||||
|
#pepper: ""
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -750,7 +750,8 @@ class AuthHandler(BaseHandler):
|
||||||
Returns:
|
Returns:
|
||||||
Hashed password (str).
|
Hashed password (str).
|
||||||
"""
|
"""
|
||||||
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
|
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||||
|
bcrypt.gensalt(self.bcrypt_rounds))
|
||||||
|
|
||||||
def validate_hash(self, password, stored_hash):
|
def validate_hash(self, password, stored_hash):
|
||||||
"""Validates that self.hash(password) == stored_hash.
|
"""Validates that self.hash(password) == stored_hash.
|
||||||
|
@ -763,6 +764,7 @@ class AuthHandler(BaseHandler):
|
||||||
Whether self.hash(password) == stored_hash (bool).
|
Whether self.hash(password) == stored_hash (bool).
|
||||||
"""
|
"""
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash
|
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||||
|
stored_hash.encode('utf-8')) == stored_hash
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -21,7 +21,7 @@ from synapse.api.errors import (
|
||||||
)
|
)
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
|
||||||
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _should_trust_id_server(self, id_server):
|
||||||
|
if id_server not in self.trusted_id_servers:
|
||||||
|
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||||
|
logger.warn(
|
||||||
|
"Trusting untrustworthy ID server %r even though it isn't"
|
||||||
|
" in the trusted id list for testing because"
|
||||||
|
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||||
|
" is set in the config",
|
||||||
|
id_server,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def threepid_from_creds(self, creds):
|
def threepid_from_creds(self, creds):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "No client_secret in creds")
|
raise SynapseError(400, "No client_secret in creds")
|
||||||
|
|
||||||
if id_server not in self.trusted_id_servers:
|
if not self._should_trust_id_server(id_server):
|
||||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
logger.warn(
|
||||||
logger.warn(
|
'%s is not a trusted ID server: rejecting 3pid ' +
|
||||||
"Trusting untrustworthy ID server %r even though it isn't"
|
'credentials', id_server
|
||||||
" in the trusted id list for testing because"
|
)
|
||||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
defer.returnValue(None)
|
||||||
" is set in the config",
|
|
||||||
id_server,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
|
||||||
'credentials', id_server)
|
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
data = {}
|
data = {}
|
||||||
try:
|
try:
|
||||||
|
@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
|
||||||
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if not self._should_trust_id_server(id_server):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Untrusted ID server '%s'" % id_server,
|
||||||
|
Codes.SERVER_NOT_TRUSTED
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'email': email,
|
'email': email,
|
||||||
'client_secret': client_secret,
|
'client_secret': client_secret,
|
||||||
|
|
|
@ -90,7 +90,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
password=None,
|
password=None,
|
||||||
generate_token=True,
|
generate_token=True,
|
||||||
guest_access_token=None,
|
guest_access_token=None,
|
||||||
make_guest=False
|
make_guest=False,
|
||||||
|
admin=False,
|
||||||
):
|
):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
|
@ -141,6 +142,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
# If the user was a guest then they already have a profile
|
# If the user was a guest then they already have a profile
|
||||||
None if was_guest else user.localpart
|
None if was_guest else user.localpart
|
||||||
),
|
),
|
||||||
|
admin=admin,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# autogen a sequential user ID
|
# autogen a sequential user ID
|
||||||
|
|
|
@ -324,6 +324,14 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
|
||||||
user = register_json["user"].encode("utf-8")
|
user = register_json["user"].encode("utf-8")
|
||||||
|
password = register_json["password"].encode("utf-8")
|
||||||
|
admin = register_json.get("admin", None)
|
||||||
|
|
||||||
|
# Its important to check as we use null bytes as HMAC field separators
|
||||||
|
if "\x00" in user:
|
||||||
|
raise SynapseError(400, "Invalid user")
|
||||||
|
if "\x00" in password:
|
||||||
|
raise SynapseError(400, "Invalid password")
|
||||||
|
|
||||||
# str() because otherwise hmac complains that 'unicode' does not
|
# str() because otherwise hmac complains that 'unicode' does not
|
||||||
# have the buffer interface
|
# have the buffer interface
|
||||||
|
@ -331,17 +339,21 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
want_mac = hmac.new(
|
want_mac = hmac.new(
|
||||||
key=self.hs.config.registration_shared_secret,
|
key=self.hs.config.registration_shared_secret,
|
||||||
msg=user,
|
|
||||||
digestmod=sha1,
|
digestmod=sha1,
|
||||||
).hexdigest()
|
)
|
||||||
|
want_mac.update(user)
|
||||||
password = register_json["password"].encode("utf-8")
|
want_mac.update("\x00")
|
||||||
|
want_mac.update(password)
|
||||||
|
want_mac.update("\x00")
|
||||||
|
want_mac.update("admin" if admin else "notadmin")
|
||||||
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
if compare_digest(want_mac, got_mac):
|
if compare_digest(want_mac, got_mac):
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
user_id, token = yield handler.register(
|
user_id, token = yield handler.register(
|
||||||
localpart=user,
|
localpart=user,
|
||||||
password=password,
|
password=password,
|
||||||
|
admin=bool(admin),
|
||||||
)
|
)
|
||||||
self._remove_session(session)
|
self._remove_session(session)
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
from synapse.types import RoomStreamToken
|
||||||
|
from .stream import lower_bound
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
|
|
||||||
stream_ordering = results[0][0]
|
stream_ordering = results[0][0]
|
||||||
topological_ordering = results[0][1]
|
topological_ordering = results[0][1]
|
||||||
|
token = RoomStreamToken(
|
||||||
|
topological_ordering, stream_ordering
|
||||||
|
)
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT sum(notif), sum(highlight)"
|
"SELECT sum(notif), sum(highlight)"
|
||||||
|
@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
" WHERE"
|
" WHERE"
|
||||||
" user_id = ?"
|
" user_id = ?"
|
||||||
" AND room_id = ?"
|
" AND room_id = ?"
|
||||||
" AND ("
|
" AND %s"
|
||||||
" topological_ordering > ?"
|
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||||
" OR (topological_ordering = ? AND stream_ordering > ?)"
|
|
||||||
")"
|
txn.execute(sql, (user_id, room_id))
|
||||||
)
|
|
||||||
txn.execute(sql, (
|
|
||||||
user_id, room_id,
|
|
||||||
topological_ordering, topological_ordering, stream_ordering
|
|
||||||
))
|
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
if row:
|
if row:
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -77,7 +77,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, user_id, token, password_hash,
|
def register(self, user_id, token, password_hash,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None):
|
create_profile_with_localpart=None, admin=False):
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -104,6 +104,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
make_guest,
|
make_guest,
|
||||||
appservice_id,
|
appservice_id,
|
||||||
create_profile_with_localpart,
|
create_profile_with_localpart,
|
||||||
|
admin
|
||||||
)
|
)
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
self.get_user_by_id.invalidate((user_id,))
|
||||||
self.is_guest.invalidate((user_id,))
|
self.is_guest.invalidate((user_id,))
|
||||||
|
@ -118,6 +119,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
make_guest,
|
make_guest,
|
||||||
appservice_id,
|
appservice_id,
|
||||||
create_profile_with_localpart,
|
create_profile_with_localpart,
|
||||||
|
admin,
|
||||||
):
|
):
|
||||||
now = int(self.clock.time())
|
now = int(self.clock.time())
|
||||||
|
|
||||||
|
@ -125,29 +127,33 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if was_guest:
|
if was_guest:
|
||||||
txn.execute("UPDATE users SET"
|
self._simple_update_one_txn(
|
||||||
" password_hash = ?,"
|
txn,
|
||||||
" upgrade_ts = ?,"
|
"users",
|
||||||
" is_guest = ?"
|
keyvalues={
|
||||||
" WHERE name = ?",
|
"name": user_id,
|
||||||
[password_hash, now, 1 if make_guest else 0, user_id])
|
},
|
||||||
|
updatevalues={
|
||||||
|
"password_hash": password_hash,
|
||||||
|
"upgrade_ts": now,
|
||||||
|
"is_guest": 1 if make_guest else 0,
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"admin": 1 if admin else 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
txn.execute("INSERT INTO users "
|
self._simple_insert_txn(
|
||||||
"("
|
txn,
|
||||||
" name,"
|
"users",
|
||||||
" password_hash,"
|
values={
|
||||||
" creation_ts,"
|
"name": user_id,
|
||||||
" is_guest,"
|
"password_hash": password_hash,
|
||||||
" appservice_id"
|
"creation_ts": now,
|
||||||
") "
|
"is_guest": 1 if make_guest else 0,
|
||||||
"VALUES (?,?,?,?,?)",
|
"appservice_id": appservice_id,
|
||||||
[
|
"admin": 1 if admin else 0,
|
||||||
user_id,
|
}
|
||||||
password_hash,
|
)
|
||||||
now,
|
|
||||||
1 if make_guest else 0,
|
|
||||||
appservice_id,
|
|
||||||
])
|
|
||||||
except self.database_engine.module.IntegrityError:
|
except self.database_engine.module.IntegrityError:
|
||||||
raise StoreError(
|
raise StoreError(
|
||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
|
|
|
@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -54,25 +55,43 @@ _STREAM_TOKEN = "stream"
|
||||||
_TOPOLOGICAL_TOKEN = "topological"
|
_TOPOLOGICAL_TOKEN = "topological"
|
||||||
|
|
||||||
|
|
||||||
def lower_bound(token):
|
def lower_bound(token, engine, inclusive=False):
|
||||||
|
inclusive = "=" if inclusive else ""
|
||||||
if token.topological is None:
|
if token.topological is None:
|
||||||
return "(%d < %s)" % (token.stream, "stream_ordering")
|
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
|
||||||
else:
|
else:
|
||||||
return "(%d < %s OR (%d = %s AND %d < %s))" % (
|
if isinstance(engine, PostgresEngine):
|
||||||
|
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
|
||||||
|
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
|
||||||
|
# use the later form when running against postgres.
|
||||||
|
return "((%d,%d) <%s (%s,%s))" % (
|
||||||
|
token.topological, token.stream, inclusive,
|
||||||
|
"topological_ordering", "stream_ordering",
|
||||||
|
)
|
||||||
|
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
|
||||||
token.topological, "topological_ordering",
|
token.topological, "topological_ordering",
|
||||||
token.topological, "topological_ordering",
|
token.topological, "topological_ordering",
|
||||||
token.stream, "stream_ordering",
|
token.stream, inclusive, "stream_ordering",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def upper_bound(token):
|
def upper_bound(token, engine, inclusive=True):
|
||||||
|
inclusive = "=" if inclusive else ""
|
||||||
if token.topological is None:
|
if token.topological is None:
|
||||||
return "(%d >= %s)" % (token.stream, "stream_ordering")
|
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
|
||||||
else:
|
else:
|
||||||
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
|
if isinstance(engine, PostgresEngine):
|
||||||
|
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
|
||||||
|
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
|
||||||
|
# use the later form when running against postgres.
|
||||||
|
return "((%d,%d) >%s (%s,%s))" % (
|
||||||
|
token.topological, token.stream, inclusive,
|
||||||
|
"topological_ordering", "stream_ordering",
|
||||||
|
)
|
||||||
|
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
|
||||||
token.topological, "topological_ordering",
|
token.topological, "topological_ordering",
|
||||||
token.topological, "topological_ordering",
|
token.topological, "topological_ordering",
|
||||||
token.stream, "stream_ordering",
|
token.stream, inclusive, "stream_ordering",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -308,18 +327,22 @@ class StreamStore(SQLBaseStore):
|
||||||
args = [False, room_id]
|
args = [False, room_id]
|
||||||
if direction == 'b':
|
if direction == 'b':
|
||||||
order = "DESC"
|
order = "DESC"
|
||||||
bounds = upper_bound(RoomStreamToken.parse(from_key))
|
bounds = upper_bound(
|
||||||
|
RoomStreamToken.parse(from_key), self.database_engine
|
||||||
|
)
|
||||||
if to_key:
|
if to_key:
|
||||||
bounds = "%s AND %s" % (
|
bounds = "%s AND %s" % (bounds, lower_bound(
|
||||||
bounds, lower_bound(RoomStreamToken.parse(to_key))
|
RoomStreamToken.parse(to_key), self.database_engine
|
||||||
)
|
))
|
||||||
else:
|
else:
|
||||||
order = "ASC"
|
order = "ASC"
|
||||||
bounds = lower_bound(RoomStreamToken.parse(from_key))
|
bounds = lower_bound(
|
||||||
|
RoomStreamToken.parse(from_key), self.database_engine
|
||||||
|
)
|
||||||
if to_key:
|
if to_key:
|
||||||
bounds = "%s AND %s" % (
|
bounds = "%s AND %s" % (bounds, upper_bound(
|
||||||
bounds, upper_bound(RoomStreamToken.parse(to_key))
|
RoomStreamToken.parse(to_key), self.database_engine
|
||||||
)
|
))
|
||||||
|
|
||||||
if int(limit) > 0:
|
if int(limit) > 0:
|
||||||
args.append(int(limit))
|
args.append(int(limit))
|
||||||
|
@ -586,32 +609,60 @@ class StreamStore(SQLBaseStore):
|
||||||
retcols=["stream_ordering", "topological_ordering"],
|
retcols=["stream_ordering", "topological_ordering"],
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_ordering = results["stream_ordering"]
|
token = RoomStreamToken(
|
||||||
topological_ordering = results["topological_ordering"]
|
results["topological_ordering"],
|
||||||
|
results["stream_ordering"],
|
||||||
query_before = (
|
|
||||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
|
||||||
" WHERE room_id = ? AND (topological_ordering < ?"
|
|
||||||
" OR (topological_ordering = ? AND stream_ordering < ?))"
|
|
||||||
" ORDER BY topological_ordering DESC, stream_ordering DESC"
|
|
||||||
" LIMIT ?"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query_after = (
|
if isinstance(self.database_engine, Sqlite3Engine):
|
||||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
# SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
|
||||||
" WHERE room_id = ? AND (topological_ordering > ?"
|
# So we give pass it to SQLite3 as the UNION ALL of the two queries.
|
||||||
" OR (topological_ordering = ? AND stream_ordering > ?))"
|
|
||||||
" ORDER BY topological_ordering ASC, stream_ordering ASC"
|
|
||||||
" LIMIT ?"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(
|
query_before = (
|
||||||
query_before,
|
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
(
|
" WHERE room_id = ? AND topological_ordering < ?"
|
||||||
room_id, topological_ordering, topological_ordering,
|
" UNION ALL"
|
||||||
stream_ordering, before_limit,
|
" SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
|
" WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
|
||||||
|
" ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
|
||||||
)
|
)
|
||||||
)
|
before_args = (
|
||||||
|
room_id, token.topological,
|
||||||
|
room_id, token.topological, token.stream,
|
||||||
|
before_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_after = (
|
||||||
|
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
|
" WHERE room_id = ? AND topological_ordering > ?"
|
||||||
|
" UNION ALL"
|
||||||
|
" SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
|
" WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
|
||||||
|
" ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
|
||||||
|
)
|
||||||
|
after_args = (
|
||||||
|
room_id, token.topological,
|
||||||
|
room_id, token.topological, token.stream,
|
||||||
|
after_limit,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_before = (
|
||||||
|
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
|
" WHERE room_id = ? AND %s"
|
||||||
|
" ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
|
||||||
|
) % (upper_bound(token, self.database_engine, inclusive=False),)
|
||||||
|
|
||||||
|
before_args = (room_id, before_limit)
|
||||||
|
|
||||||
|
query_after = (
|
||||||
|
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
|
" WHERE room_id = ? AND %s"
|
||||||
|
" ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
|
||||||
|
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||||
|
|
||||||
|
after_args = (room_id, after_limit)
|
||||||
|
|
||||||
|
txn.execute(query_before, before_args)
|
||||||
|
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
events_before = [r["event_id"] for r in rows]
|
events_before = [r["event_id"] for r in rows]
|
||||||
|
@ -623,17 +674,11 @@ class StreamStore(SQLBaseStore):
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
start_token = str(RoomStreamToken(
|
start_token = str(RoomStreamToken(
|
||||||
topological_ordering,
|
token.topological,
|
||||||
stream_ordering - 1,
|
token.stream - 1,
|
||||||
))
|
))
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(query_after, after_args)
|
||||||
query_after,
|
|
||||||
(
|
|
||||||
room_id, topological_ordering, topological_ordering,
|
|
||||||
stream_ordering, after_limit,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
events_after = [r["event_id"] for r in rows]
|
events_after = [r["event_id"] for r in rows]
|
||||||
|
@ -644,10 +689,7 @@ class StreamStore(SQLBaseStore):
|
||||||
rows[-1]["stream_ordering"],
|
rows[-1]["stream_ordering"],
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
end_token = str(RoomStreamToken(
|
end_token = str(token)
|
||||||
topological_ordering,
|
|
||||||
stream_ordering,
|
|
||||||
))
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"before": {
|
"before": {
|
||||||
|
|
Loading…
Reference in New Issue