mirror of
				https://github.com/matrix-org/synapse.git
				synced 2025-11-03 21:57:26 +00:00 
			
		
		
		
	Remove not needed database updates in modify user admin API (#10627)
This commit is contained in:
		
							parent
							
								
									0c3565da4c
								
							
						
					
					
						commit
						220f901229
					
				
							
								
								
									
										1
									
								
								changelog.d/10627.misc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/10627.misc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
Remove not needed database updates in modify user admin API.
 | 
			
		||||
@ -21,11 +21,15 @@ It returns a JSON body like the following:
 | 
			
		||||
    "threepids": [
 | 
			
		||||
        {
 | 
			
		||||
            "medium": "email",
 | 
			
		||||
            "address": "<user_mail_1>"
 | 
			
		||||
            "address": "<user_mail_1>",
 | 
			
		||||
            "added_at": 1586458409743,
 | 
			
		||||
            "validated_at": 1586458409743
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
            "medium": "email",
 | 
			
		||||
            "address": "<user_mail_2>"
 | 
			
		||||
            "address": "<user_mail_2>",
 | 
			
		||||
            "added_at": 1586458409743,
 | 
			
		||||
            "validated_at": 1586458409743
 | 
			
		||||
        }
 | 
			
		||||
    ],
 | 
			
		||||
    "avatar_url": "<avatar_url>",
 | 
			
		||||
 | 
			
		||||
@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
 | 
			
		||||
        if not isinstance(deactivate, bool):
 | 
			
		||||
            raise SynapseError(400, "'deactivated' parameter is not of type boolean")
 | 
			
		||||
 | 
			
		||||
        # convert into List[Tuple[str, str]]
 | 
			
		||||
        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
 | 
			
		||||
        if external_ids is not None:
 | 
			
		||||
            new_external_ids = []
 | 
			
		||||
            for external_id in external_ids:
 | 
			
		||||
                new_external_ids.append(
 | 
			
		||||
                    (external_id["auth_provider"], external_id["external_id"])
 | 
			
		||||
                )
 | 
			
		||||
            new_external_ids = {
 | 
			
		||||
                (external_id["auth_provider"], external_id["external_id"])
 | 
			
		||||
                for external_id in external_ids
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
 | 
			
		||||
        if threepids is not None:
 | 
			
		||||
            new_threepids = {
 | 
			
		||||
                (threepid["medium"], threepid["address"]) for threepid in threepids
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if user:  # modify user
 | 
			
		||||
            if "displayname" in body:
 | 
			
		||||
@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if threepids is not None:
 | 
			
		||||
                # remove old threepids from user
 | 
			
		||||
                old_threepids = await self.store.user_get_threepids(user_id)
 | 
			
		||||
                for threepid in old_threepids:
 | 
			
		||||
                # get changed threepids (added and removed)
 | 
			
		||||
                # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
 | 
			
		||||
                cur_threepids = {
 | 
			
		||||
                    (threepid["medium"], threepid["address"])
 | 
			
		||||
                    for threepid in await self.store.user_get_threepids(user_id)
 | 
			
		||||
                }
 | 
			
		||||
                add_threepids = new_threepids - cur_threepids
 | 
			
		||||
                del_threepids = cur_threepids - new_threepids
 | 
			
		||||
 | 
			
		||||
                # remove old threepids
 | 
			
		||||
                for medium, address in del_threepids:
 | 
			
		||||
                    try:
 | 
			
		||||
                        await self.auth_handler.delete_threepid(
 | 
			
		||||
                            user_id, threepid["medium"], threepid["address"], None
 | 
			
		||||
                            user_id, medium, address, None
 | 
			
		||||
                        )
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        logger.exception("Failed to remove threepids")
 | 
			
		||||
                        raise SynapseError(500, "Failed to remove threepids")
 | 
			
		||||
 | 
			
		||||
                # add new threepids to user
 | 
			
		||||
                # add new threepids
 | 
			
		||||
                current_time = self.hs.get_clock().time_msec()
 | 
			
		||||
                for threepid in threepids:
 | 
			
		||||
                for medium, address in add_threepids:
 | 
			
		||||
                    await self.auth_handler.add_threepid(
 | 
			
		||||
                        user_id, threepid["medium"], threepid["address"], current_time
 | 
			
		||||
                        user_id, medium, address, current_time
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            if external_ids is not None:
 | 
			
		||||
                # get changed external_ids (added and removed)
 | 
			
		||||
                cur_external_ids = await self.store.get_external_ids_by_user(user_id)
 | 
			
		||||
                add_external_ids = set(new_external_ids) - set(cur_external_ids)
 | 
			
		||||
                del_external_ids = set(cur_external_ids) - set(new_external_ids)
 | 
			
		||||
                cur_external_ids = set(
 | 
			
		||||
                    await self.store.get_external_ids_by_user(user_id)
 | 
			
		||||
                )
 | 
			
		||||
                add_external_ids = new_external_ids - cur_external_ids
 | 
			
		||||
                del_external_ids = cur_external_ids - new_external_ids
 | 
			
		||||
 | 
			
		||||
                # remove old external_ids
 | 
			
		||||
                for auth_provider, external_id in del_external_ids:
 | 
			
		||||
@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
 | 
			
		||||
 | 
			
		||||
            if threepids is not None:
 | 
			
		||||
                current_time = self.hs.get_clock().time_msec()
 | 
			
		||||
                for threepid in threepids:
 | 
			
		||||
                for medium, address in new_threepids:
 | 
			
		||||
                    await self.auth_handler.add_threepid(
 | 
			
		||||
                        user_id, threepid["medium"], threepid["address"], current_time
 | 
			
		||||
                        user_id, medium, address, current_time
 | 
			
		||||
                    )
 | 
			
		||||
                    if (
 | 
			
		||||
                        self.hs.config.email_enable_notifs
 | 
			
		||||
@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
 | 
			
		||||
                            kind="email",
 | 
			
		||||
                            app_id="m.email",
 | 
			
		||||
                            app_display_name="Email Notifications",
 | 
			
		||||
                            device_display_name=threepid["address"],
 | 
			
		||||
                            pushkey=threepid["address"],
 | 
			
		||||
                            device_display_name=address,
 | 
			
		||||
                            pushkey=address,
 | 
			
		||||
                            lang=None,  # We don't know a user's language here
 | 
			
		||||
                            data={},
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		||||
        )
 | 
			
		||||
        return user_id
 | 
			
		||||
 | 
			
		||||
    def get_user_id_by_threepid_txn(self, txn, medium, address):
 | 
			
		||||
    def get_user_id_by_threepid_txn(
 | 
			
		||||
        self, txn, medium: str, address: str
 | 
			
		||||
    ) -> Optional[str]:
 | 
			
		||||
        """Returns user id from threepid
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            txn (cursor):
 | 
			
		||||
            medium (str): threepid medium e.g. email
 | 
			
		||||
            address (str): threepid address e.g. me@example.com
 | 
			
		||||
            medium: threepid medium e.g. email
 | 
			
		||||
            address: threepid address e.g. me@example.com
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            str|None: user id or None if no user id/threepid mapping exists
 | 
			
		||||
            user id, or None if no user id/threepid mapping exists
 | 
			
		||||
        """
 | 
			
		||||
        ret = self.db_pool.simple_select_one_txn(
 | 
			
		||||
            txn,
 | 
			
		||||
@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		||||
            return ret["user_id"]
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
 | 
			
		||||
    async def user_add_threepid(
 | 
			
		||||
        self,
 | 
			
		||||
        user_id: str,
 | 
			
		||||
        medium: str,
 | 
			
		||||
        address: str,
 | 
			
		||||
        validated_at: int,
 | 
			
		||||
        added_at: int,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        await self.db_pool.simple_upsert(
 | 
			
		||||
            "user_threepids",
 | 
			
		||||
            {"medium": medium, "address": address},
 | 
			
		||||
            {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def user_get_threepids(self, user_id):
 | 
			
		||||
    async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
 | 
			
		||||
        return await self.db_pool.simple_select_list(
 | 
			
		||||
            "user_threepids",
 | 
			
		||||
            {"user_id": user_id},
 | 
			
		||||
@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 | 
			
		||||
            "user_get_threepids",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def user_delete_threepid(self, user_id, medium, address) -> None:
 | 
			
		||||
    async def user_delete_threepid(
 | 
			
		||||
        self, user_id: str, medium: str, address: str
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        await self.db_pool.simple_delete(
 | 
			
		||||
            "user_threepids",
 | 
			
		||||
            keyvalues={"user_id": user_id, "medium": medium, "address": address},
 | 
			
		||||
 | 
			
		||||
@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
        self.assertEqual("Bob's name", channel.json_body["displayname"])
 | 
			
		||||
        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
 | 
			
		||||
        self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
 | 
			
		||||
        self.assertEqual(1, len(channel.json_body["threepids"]))
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            "external_id1", channel.json_body["external_ids"][0]["external_id"]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(1, len(channel.json_body["external_ids"]))
 | 
			
		||||
        self.assertFalse(channel.json_body["admin"])
 | 
			
		||||
        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 | 
			
		||||
        self._check_fields(channel.json_body)
 | 
			
		||||
@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
        Test setting threepid for an other user.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Delete old and add new threepid to user
 | 
			
		||||
        # Add two threepids to user
 | 
			
		||||
        channel = self.make_request(
 | 
			
		||||
            "PUT",
 | 
			
		||||
            self.url_other_user,
 | 
			
		||||
            access_token=self.admin_user_tok,
 | 
			
		||||
            content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
 | 
			
		||||
            content={
 | 
			
		||||
                "threepids": [
 | 
			
		||||
                    {"medium": "email", "address": "bob1@bob.bob"},
 | 
			
		||||
                    {"medium": "email", "address": "bob2@bob.bob"},
 | 
			
		||||
                ],
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(200, channel.code, msg=channel.json_body)
 | 
			
		||||
        self.assertEqual("@user:test", channel.json_body["name"])
 | 
			
		||||
        self.assertEqual(2, len(channel.json_body["threepids"]))
 | 
			
		||||
        # result does not always have the same sort order, therefore it becomes sorted
 | 
			
		||||
        sorted_result = sorted(
 | 
			
		||||
            channel.json_body["threepids"], key=lambda k: k["address"]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual("email", sorted_result[0]["medium"])
 | 
			
		||||
        self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
 | 
			
		||||
        self.assertEqual("email", sorted_result[1]["medium"])
 | 
			
		||||
        self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
 | 
			
		||||
        self._check_fields(channel.json_body)
 | 
			
		||||
 | 
			
		||||
        # Set a new and remove a threepid
 | 
			
		||||
        channel = self.make_request(
 | 
			
		||||
            "PUT",
 | 
			
		||||
            self.url_other_user,
 | 
			
		||||
            access_token=self.admin_user_tok,
 | 
			
		||||
            content={
 | 
			
		||||
                "threepids": [
 | 
			
		||||
                    {"medium": "email", "address": "bob2@bob.bob"},
 | 
			
		||||
                    {"medium": "email", "address": "bob3@bob.bob"},
 | 
			
		||||
                ],
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(200, channel.code, msg=channel.json_body)
 | 
			
		||||
        self.assertEqual("@user:test", channel.json_body["name"])
 | 
			
		||||
        self.assertEqual(2, len(channel.json_body["threepids"]))
 | 
			
		||||
        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
 | 
			
		||||
        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
 | 
			
		||||
        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
 | 
			
		||||
        self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
 | 
			
		||||
        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
 | 
			
		||||
        self._check_fields(channel.json_body)
 | 
			
		||||
 | 
			
		||||
        # Get user
 | 
			
		||||
        channel = self.make_request(
 | 
			
		||||
@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(200, channel.code, msg=channel.json_body)
 | 
			
		||||
        self.assertEqual("@user:test", channel.json_body["name"])
 | 
			
		||||
        self.assertEqual(2, len(channel.json_body["threepids"]))
 | 
			
		||||
        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
 | 
			
		||||
        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
 | 
			
		||||
        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
 | 
			
		||||
        self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
 | 
			
		||||
        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
 | 
			
		||||
        self._check_fields(channel.json_body)
 | 
			
		||||
 | 
			
		||||
        # Remove threepids
 | 
			
		||||
        channel = self.make_request(
 | 
			
		||||
            "PUT",
 | 
			
		||||
            self.url_other_user,
 | 
			
		||||
            access_token=self.admin_user_tok,
 | 
			
		||||
            content={"threepids": []},
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(200, channel.code, msg=channel.json_body)
 | 
			
		||||
        self.assertEqual("@user:test", channel.json_body["name"])
 | 
			
		||||
        self.assertEqual(0, len(channel.json_body["threepids"]))
 | 
			
		||||
        self._check_fields(channel.json_body)
 | 
			
		||||
 | 
			
		||||
    def test_set_external_id(self):
 | 
			
		||||
        """
 | 
			
		||||
@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(200, channel.code, msg=channel.json_body)
 | 
			
		||||
        self.assertEqual("@user:test", channel.json_body["name"])
 | 
			
		||||
        self.assertEqual(2, len(channel.json_body["external_ids"]))
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            channel.json_body["external_ids"],
 | 
			
		||||
            [
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user