Merge pull request #1117 from matrix-org/erikj/fix_state

Ensure we don't mutate state cache entries
Erik Johnston 2016-09-14 16:50:37 +01:00 committed by GitHub
commit 264a48aedf
3 changed files with 30 additions and 20 deletions
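
The bug class being fixed: the state cache hands callers the very dict stored in a _StateCacheEntry, so a caller that updates what it thinks is its own copy silently edits the cached entry for every later caller. A toy sketch of the failure mode, using illustrative names rather than Synapse's real APIs:

    # Toy illustration (not Synapse code): a cache that returns its stored
    # dict directly, and a caller that mutates the result in place.
    cache = {}

    def get_state_ids(room_id):
        # Hands back the cached dict itself, so all callers share one object.
        return cache.setdefault(room_id, {("m.room.create", ""): "$create"})

    state = get_state_ids("!room:example.org")
    state[("m.room.name", "")] = "$name"          # oops: writes into the cache

    assert ("m.room.name", "") in get_state_ids("!room:example.org")

    # The commit's fix, applied at every such site: copy before mutating
    # (dict(...) at call sites, frozendict(...) inside the cache entry itself).
    safe = dict(get_state_ids("!room:example.org"))
    safe[("m.room.topic", "")] = "$topic"
    assert ("m.room.topic", "") not in get_state_ids("!room:example.org")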


@@ -1585,10 +1585,12 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         different_auth = event_auth_events - current_state
 
+        context.current_state_ids = dict(context.current_state_ids)
         context.current_state_ids.update({
             k: a.event_id for k, a in auth_events.items()
             if k != event_key
         })
+        context.prev_state_ids = dict(context.prev_state_ids)
         context.prev_state_ids.update({
             k: a.event_id for k, a in auth_events.items()
         })
@@ -1670,10 +1672,12 @@ class FederationHandler(BaseHandler):
         # 4. Look at rejects and their proofs.
         # TODO.
 
+        context.current_state_ids = dict(context.current_state_ids)
         context.current_state_ids.update({
             k: a.event_id for k, a in auth_events.items()
             if k != event_key
         })
+        context.prev_state_ids = dict(context.prev_state_ids)
         context.prev_state_ids.update({
             k: a.event_id for k, a in auth_events.items()
         })
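
Both hunks above apply the same fix: at this point context.current_state_ids and context.prev_state_ids may still alias a cached state dict, so each is rebound to a private dict(...) copy before .update() is called on it. An illustration of why the rebind matters, with a hypothetical Ctx standing in for Synapse's EventContext:

    # Hypothetical names; only the copy-then-update pattern is the point.
    cached_state_ids = {("m.room.create", ""): "$create"}   # pretend cache entry

    class Ctx(object):
        pass

    context = Ctx()
    context.current_state_ids = cached_state_ids            # aliases the cache

    context.current_state_ids = dict(context.current_state_ids)   # private copy
    context.current_state_ids.update({("m.room.name", ""): "$name"})

    assert ("m.room.name", "") not in cached_state_ids      # cache left intact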


@@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
 
 from collections import namedtuple
+from frozendict import frozendict
 
 import logging
 import hashlib
@@ -58,11 +59,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
-        self.state = state
+        self.state = frozendict(state)
         self.state_group = state_group
 
         self.prev_group = prev_group
-        self.delta_ids = delta_ids
+        self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
         # The `state_id` is a unique ID we generate that can be used as ID for
         # this collection of state. Usually this would be the same as the
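
Wrapping the entry's members in frozendict (the PyPI package imported above, not a stdlib type) turns any remaining accidental write into a loud TypeError instead of silent cache corruption; callers that genuinely need a mutable view copy first. A quick check of that behaviour:

    from frozendict import frozendict

    state = frozendict({("m.room.member", "@alice:example.org"): "$join"})

    try:
        state[("m.room.member", "@bob:example.org")] = "$join2"
    except TypeError:
        pass                                  # item assignment is rejected

    mutable = dict(state)                     # callers needing changes copy first
    mutable[("m.room.member", "@bob:example.org")] = "$join2"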
@@ -237,13 +238,7 @@ class StateHandler(object):
         context.prev_state_ids = curr_state
         if event.is_state():
             context.state_group = self.store.get_next_state_group()
-        else:
-            if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
-                entry.state_id = entry.state_group
-            context.state_group = entry.state_group
 
-        if event.is_state():
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
@@ -255,10 +250,15 @@ class StateHandler(object):
             context.prev_group = entry.prev_group
             context.delta_ids = entry.delta_ids
             if context.delta_ids is not None:
+                context.delta_ids = dict(context.delta_ids)
                 context.delta_ids[key] = event.event_id
         else:
-            context.current_state_ids = context.prev_state_ids
+            if entry.state_group is None:
+                entry.state_group = self.store.get_next_state_group()
+                entry.state_id = entry.state_group
+
+            context.state_group = entry.state_group
+            context.current_state_ids = context.prev_state_ids
             context.prev_group = entry.prev_group
             context.delta_ids = entry.delta_ids
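
In compute_event_context the same copy-on-write idea is applied to delta_ids: the value taken from the cached entry is shared (and now frozen), so it is only copied with dict(...) when this event actually needs to record its own (type, state_key) -> event_id change. A sketch with a hypothetical helper name:

    # Hypothetical helper; mirrors the hunk above, not Synapse's exact code.
    def apply_event_delta(context, entry, key, event_id):
        context.prev_group = entry.prev_group
        context.delta_ids = entry.delta_ids
        if context.delta_ids is not None:
            context.delta_ids = dict(context.delta_ids)   # detach from the cache entry
            context.delta_ids[key] = event_id             # only our copy changes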


@@ -817,8 +817,13 @@ class StateStore(SQLBaseStore):
     @defer.inlineCallbacks
     def _background_index_state(self, progress, batch_size):
-        def reindex_txn(txn):
+        def reindex_txn(conn):
+            conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
-                txn.execute(
-                    "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
-                    " ON state_groups_state(state_group, type, state_key)"
+                # postgres insists on autocommit for the index
+                conn.set_session(autocommit=True)
+                try:
+                    txn = conn.cursor()
+                    txn.execute(
+                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+                        " ON state_groups_state(state_group, type, state_key)"
@@ -826,7 +831,10 @@ class StateStore(SQLBaseStore):
-                txn.execute(
-                    "DROP INDEX IF EXISTS state_groups_state_id"
-                )
+                    txn.execute(
+                        "DROP INDEX IF EXISTS state_groups_state_id"
+                    )
+                finally:
+                    conn.set_session(autocommit=False)
             else:
+                txn = conn.cursor()
                 txn.execute(
                     "CREATE INDEX state_groups_state_type_idx"
                     " ON state_groups_state(state_group, type, state_key)"
@@ -835,9 +843,7 @@ class StateStore(SQLBaseStore):
                     "DROP INDEX IF EXISTS state_groups_state_id"
                 )
 
-        yield self.runInteraction(
-            self.STATE_GROUP_INDEX_UPDATE_NAME, reindex_txn
-        )
+        yield self.runWithConnection(reindex_txn)
 
         yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
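
The storage change is about Postgres rather than caching: CREATE INDEX CONCURRENTLY cannot run inside a transaction block, so the background update now takes a raw connection via runWithConnection, rolls back any open transaction, and flips the connection to autocommit around the DDL. A hedged psycopg2 sketch of the same idea (connection details are illustrative, not Synapse's code):

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")    # illustrative DSN
    conn.rollback()                              # leave any open transaction first
    conn.set_session(autocommit=True)            # required for CONCURRENTLY
    try:
        cur = conn.cursor()
        cur.execute(
            "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
            " ON state_groups_state(state_group, type, state_key)"
        )
        cur.execute("DROP INDEX IF EXISTS state_groups_state_id")
    finally:
        conn.set_session(autocommit=False)       # restore normal transactions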