Convert DeviceLastConnectionInfo to attrs. (#16507)

To improve type safety & memory usage.
This commit is contained in:
Patrick Cloke 2023-10-17 08:47:42 -04:00 committed by GitHub
parent 77dfc1f939
commit 6ad1f9eac2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 104 additions and 103 deletions

1
changelog.d/16507.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -14,17 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import ( from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EduTypes, EventTypes from synapse.api.constants import EduTypes, EventTypes
@ -41,6 +31,7 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
JsonMapping, JsonMapping,
@ -1008,14 +999,14 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips( def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
) -> None: ) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]))
device.update( device.update(
{ {
"last_seen_user_agent": ip.get("user_agent"), "last_seen_user_agent": ip.user_agent if ip else None,
"last_seen_ts": ip.get("last_seen"), "last_seen_ts": ip.last_seen if ip else None,
"last_seen_ip": ip.get("ip"), "last_seen_ip": ip.ip if ip else None,
} }
) )

View File

@ -15,6 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
import attr
from typing_extensions import TypedDict from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -42,7 +43,8 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000 LAST_SEEN_GRANULARITY = 120 * 1000
class DeviceLastConnectionInfo(TypedDict): @attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLastConnectionInfo:
"""Metadata for the last connection seen for a user and device combination""" """Metadata for the last connection seen for a user and device combination"""
# These types must match the columns in the `devices` table # These types must match the columns in the `devices` table
@ -499,24 +501,29 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user device_id: If None fetches all devices for the user
Returns: Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
keys giving the column names from the devices table.
""" """
keyvalues = {"user_id": user_id} keyvalues = {"user_id": user_id}
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
res = cast( res = await self.db_pool.simple_select_list(
List[DeviceLastConnectionInfo], table="devices",
await self.db_pool.simple_select_list( keyvalues=keyvalues,
table="devices", retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
) )
return {(d["user_id"], d["device_id"]): d for d in res} return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
)
for d in res
}
async def _get_user_ip_and_agents_from_database( async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0 self, user: UserID, since_ts: int = 0
@ -683,8 +690,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
device_id: If None fetches all devices for the user device_id: If None fetches all devices for the user
Returns: Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
keys giving the column names from the devices table.
""" """
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id) ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
@ -705,13 +711,13 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
continue continue
if not device_id or did == device_id: if not device_id or did == device_id:
ret[(user_id, did)] = { ret[(user_id, did)] = DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"ip": ip, ip=ip,
"user_agent": user_agent, user_agent=user_agent,
"device_id": did, device_id=did,
"last_seen": last_seen, last_seen=last_seen,
} )
return ret return ret
async def get_user_ip_and_agents( async def get_user_ip_and_agents(

View File

@ -24,7 +24,10 @@ import synapse.rest.admin
from synapse.http.site import XForwardedForRequest from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import (
LAST_SEEN_GRANULARITY,
DeviceLastConnectionInfo,
)
from synapse.types import UserID from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
@ -65,15 +68,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
r = result[(user_id, device_id)] r = result[(user_id, device_id)]
self.assertLessEqual( self.assertEqual(
{ DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id, device_id=device_id,
"ip": "ip", ip="ip",
"user_agent": "user_agent", user_agent="user_agent",
"last_seen": 12345678000, last_seen=12345678000,
}.items(), ),
r.items(), r,
) )
def test_insert_new_client_ip_none_device_id(self) -> None: def test_insert_new_client_ip_none_device_id(self) -> None:
@ -201,13 +204,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
result, result,
{ {
(user_id, device_id): { (user_id, device_id): DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id, device_id=device_id,
"ip": "ip", ip="ip",
"user_agent": "user_agent", user_agent="user_agent",
"last_seen": 12345678000, last_seen=12345678000,
}, ),
}, },
) )
@ -292,20 +295,20 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
result, result,
{ {
(user_id, device_id_1): { (user_id, device_id_1): DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id_1, device_id=device_id_1,
"ip": "ip_1", ip="ip_1",
"user_agent": "user_agent_1", user_agent="user_agent_1",
"last_seen": 12345678000, last_seen=12345678000,
}, ),
(user_id, device_id_2): { (user_id, device_id_2): DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id_2, device_id=device_id_2,
"ip": "ip_2", ip="ip_2",
"user_agent": "user_agent_3", user_agent="user_agent_3",
"last_seen": 12345688000 + LAST_SEEN_GRANULARITY, last_seen=12345688000 + LAST_SEEN_GRANULARITY,
}, ),
}, },
) )
@ -526,15 +529,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
r = result[(user_id, device_id)] r = result[(user_id, device_id)]
self.assertLessEqual( self.assertEqual(
{ DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id, device_id=device_id,
"ip": None, ip=None,
"user_agent": None, user_agent=None,
"last_seen": None, last_seen=None,
}.items(), ),
r.items(), r,
) )
# Register the background update to run again. # Register the background update to run again.
@ -561,15 +564,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
r = result[(user_id, device_id)] r = result[(user_id, device_id)]
self.assertLessEqual( self.assertEqual(
{ DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id, device_id=device_id,
"ip": "ip", ip="ip",
"user_agent": "user_agent", user_agent="user_agent",
"last_seen": 0, last_seen=0,
}.items(), ),
r.items(), r,
) )
def test_old_user_ips_pruned(self) -> None: def test_old_user_ips_pruned(self) -> None:
@ -640,15 +643,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
) )
r = result2[(user_id, device_id)] r = result2[(user_id, device_id)]
self.assertLessEqual( self.assertEqual(
{ DeviceLastConnectionInfo(
"user_id": user_id, user_id=user_id,
"device_id": device_id, device_id=device_id,
"ip": "ip", ip="ip",
"user_agent": "user_agent", user_agent="user_agent",
"last_seen": 0, last_seen=0,
}.items(), ),
r.items(), r,
) )
def test_invalid_user_agents_are_ignored(self) -> None: def test_invalid_user_agents_are_ignored(self) -> None:
@ -777,13 +780,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
self.store.get_last_client_ip_by_device(self.user_id, device_id) self.store.get_last_client_ip_by_device(self.user_id, device_id)
) )
r = result[(self.user_id, device_id)] r = result[(self.user_id, device_id)]
self.assertLessEqual( self.assertEqual(
{ DeviceLastConnectionInfo(
"user_id": self.user_id, user_id=self.user_id,
"device_id": device_id, device_id=device_id,
"ip": expected_ip, ip=expected_ip,
"user_agent": "Mozzila pizza", user_agent="Mozzila pizza",
"last_seen": 123456100, last_seen=123456100,
}.items(), ),
r.items(), r,
) )