Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens. (#12775)
This commit is contained in:
parent
3d8839c30c
commit
19d79b6ebe
|
@ -0,0 +1 @@
|
||||||
|
Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.
|
|
@ -288,7 +288,6 @@ class StateHandler:
|
||||||
#
|
#
|
||||||
# first of all, figure out the state before the event
|
# first of all, figure out the state before the event
|
||||||
#
|
#
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
# if we're given the state before the event, then we use that
|
# if we're given the state before the event, then we use that
|
||||||
state_ids_before_event: StateMap[str] = {
|
state_ids_before_event: StateMap[str] = {
|
||||||
|
@ -419,33 +418,37 @@ class StateHandler:
|
||||||
"""
|
"""
|
||||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||||
|
|
||||||
# map from state group id to the state in that state group (where
|
state_groups = await self.state_store.get_state_group_for_events(event_ids)
|
||||||
# 'state' is a map from state key to event id)
|
|
||||||
# dict[int, dict[(str, str), str]]
|
state_group_ids = state_groups.values()
|
||||||
state_groups_ids = await self.state_store.get_state_groups_ids(
|
|
||||||
room_id, event_ids
|
# check if each event has same state group id, if so there's no state to resolve
|
||||||
|
state_group_ids_set = set(state_group_ids)
|
||||||
|
if len(state_group_ids_set) == 1:
|
||||||
|
(state_group_id,) = state_group_ids_set
|
||||||
|
state = await self.state_store.get_state_for_groups(state_group_ids_set)
|
||||||
|
prev_group, delta_ids = await self.state_store.get_state_group_delta(
|
||||||
|
state_group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(state_groups_ids) == 0:
|
|
||||||
return _StateCacheEntry(state={}, state_group=None)
|
|
||||||
elif len(state_groups_ids) == 1:
|
|
||||||
name, state_list = list(state_groups_ids.items()).pop()
|
|
||||||
|
|
||||||
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
|
|
||||||
|
|
||||||
return _StateCacheEntry(
|
return _StateCacheEntry(
|
||||||
state=state_list,
|
state=state[state_group_id],
|
||||||
state_group=name,
|
state_group=state_group_id,
|
||||||
prev_group=prev_group,
|
prev_group=prev_group,
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
)
|
)
|
||||||
|
elif len(state_group_ids_set) == 0:
|
||||||
|
return _StateCacheEntry(state={}, state_group=None)
|
||||||
|
|
||||||
room_version = await self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version_id(room_id)
|
||||||
|
|
||||||
|
state_to_resolve = await self.state_store.get_state_for_groups(
|
||||||
|
state_group_ids_set
|
||||||
|
)
|
||||||
|
|
||||||
result = await self._state_resolution_handler.resolve_state_groups(
|
result = await self._state_resolution_handler.resolve_state_groups(
|
||||||
room_id,
|
room_id,
|
||||||
room_version,
|
room_version,
|
||||||
state_groups_ids,
|
state_to_resolve,
|
||||||
None,
|
None,
|
||||||
state_res_store=StateResolutionStore(self.store),
|
state_res_store=StateResolutionStore(self.store),
|
||||||
)
|
)
|
||||||
|
|
|
@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||||
group: int,
|
group: int,
|
||||||
state_filter: StateFilter,
|
state_filter: StateFilter,
|
||||||
) -> Tuple[MutableStateMap[str], bool]:
|
) -> Tuple[MutableStateMap[str], bool]:
|
||||||
"""Checks if group is in cache. See `_get_state_for_groups`
|
"""Checks if group is in cache. See `get_state_for_groups`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache: the state group cache to use
|
cache: the state group cache to use
|
||||||
|
|
|
@ -586,7 +586,7 @@ class StateGroupStorage:
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
event_to_groups = await self._get_state_group_for_events(event_ids)
|
event_to_groups = await self.get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||||
|
@ -602,7 +602,7 @@ class StateGroupStorage:
|
||||||
Returns:
|
Returns:
|
||||||
Resolves to a map of (type, state_key) -> event_id
|
Resolves to a map of (type, state_key) -> event_id
|
||||||
"""
|
"""
|
||||||
group_to_state = await self._get_state_for_groups((state_group,))
|
group_to_state = await self.get_state_for_groups((state_group,))
|
||||||
|
|
||||||
return group_to_state[state_group]
|
return group_to_state[state_group]
|
||||||
|
|
||||||
|
@ -675,7 +675,7 @@ class StateGroupStorage:
|
||||||
RuntimeError if we don't have a state group for one or more of the events
|
RuntimeError if we don't have a state group for one or more of the events
|
||||||
(ie they are outliers or unknown)
|
(ie they are outliers or unknown)
|
||||||
"""
|
"""
|
||||||
event_to_groups = await self._get_state_group_for_events(event_ids)
|
event_to_groups = await self.get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
|
@ -716,7 +716,7 @@ class StateGroupStorage:
|
||||||
RuntimeError if we don't have a state group for one or more of the events
|
RuntimeError if we don't have a state group for one or more of the events
|
||||||
(ie they are outliers or unknown)
|
(ie they are outliers or unknown)
|
||||||
"""
|
"""
|
||||||
event_to_groups = await self._get_state_group_for_events(event_ids)
|
event_to_groups = await self.get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
|
@ -774,7 +774,7 @@ class StateGroupStorage:
|
||||||
)
|
)
|
||||||
return state_map[event_id]
|
return state_map[event_id]
|
||||||
|
|
||||||
def _get_state_for_groups(
|
def get_state_for_groups(
|
||||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||||
"""Gets the state at each of a list of state groups, optionally
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
|
@ -792,7 +792,7 @@ class StateGroupStorage:
|
||||||
groups, state_filter or StateFilter.all()
|
groups, state_filter or StateFilter.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_state_group_for_events(
|
async def get_state_group_for_events(
|
||||||
self,
|
self,
|
||||||
event_ids: Collection[str],
|
event_ids: Collection[str],
|
||||||
await_full_state: bool = True,
|
await_full_state: bool = True,
|
||||||
|
|
|
@ -129,6 +129,19 @@ class _DummyStore:
|
||||||
async def get_room_version_id(self, room_id):
|
async def get_room_version_id(self, room_id):
|
||||||
return RoomVersions.V1.identifier
|
return RoomVersions.V1.identifier
|
||||||
|
|
||||||
|
async def get_state_group_for_events(self, event_ids):
|
||||||
|
res = {}
|
||||||
|
for event in event_ids:
|
||||||
|
res[event] = self._event_to_state_group[event]
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def get_state_for_groups(self, groups):
|
||||||
|
res = {}
|
||||||
|
for group in groups:
|
||||||
|
state = self._group_to_state[group]
|
||||||
|
res[group] = state
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class DictObj(dict):
|
class DictObj(dict):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue