Stop using deprecated `keyIds` param on /key/v2/server (#14525)

Fixes #14523.
This commit is contained in:
Richard van der Hoff 2022-11-30 11:59:57 +00:00 committed by GitHub
parent 13aa29db1d
commit ecb6fe9d9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 83 deletions

View File

@ -0,0 +1 @@
Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`.

View File

@ -1 +0,0 @@
Fix a bug introduced in Synapse 0.9 where it would fail to fetch server keys whose IDs contain a forward slash.

View File

@ -0,0 +1 @@
Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`.

View File

@ -14,7 +14,6 @@
import abc import abc
import logging import logging
import urllib
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr import attr
@ -813,31 +812,27 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {} results = {}
async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None: async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None:
server_name = key_to_fetch_item.server_name server_name = key_to_fetch_item.server_name
key_ids = key_to_fetch_item.key_ids
try: try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) keys = await self.get_server_verify_keys_v2_direct(server_name)
results[server_name] = keys results[server_name] = keys
except KeyLookupError as e: except KeyLookupError as e:
logger.warning( logger.warning("Error looking up keys from %s: %s", server_name, e)
"Error looking up keys %s from %s: %s", key_ids, server_name, e
)
except Exception: except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name) logger.exception("Error getting keys from %s", server_name)
await yieldable_gather_results(get_key, keys_to_fetch) await yieldable_gather_results(get_keys, keys_to_fetch)
return results return results
async def get_server_verify_key_v2_direct( async def get_server_verify_keys_v2_direct(
self, server_name: str, key_ids: Iterable[str] self, server_name: str
) -> Dict[str, FetchKeyResult]: ) -> Dict[str, FetchKeyResult]:
""" """
Args: Args:
server_name: server_name: Server to request keys from
key_ids:
Returns: Returns:
Map from key ID to lookup result Map from key ID to lookup result
@ -845,19 +840,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
Raises: Raises:
KeyLookupError if there was a problem making the lookup KeyLookupError if there was a problem making the lookup
""" """
keys: Dict[str, FetchKeyResult] = {}
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
try: try:
response = await self.client.get_json( response = await self.client.get_json(
destination=server_name, destination=server_name,
path="/_matrix/key/v2/server/" path="/_matrix/key/v2/server",
+ urllib.parse.quote(requested_key_id, safe=""),
ignore_backoff=True, ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an # we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's # easy request to handle, so if it doesn't reply within 10s, it's
@ -886,16 +873,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
% (server_name, response["server_name"]) % (server_name, response["server_name"])
) )
response_keys = await self.process_v2_response( return await self.process_v2_response(
from_server=server_name, from_server=server_name,
response_json=response, response_json=response,
time_added_ms=time_now_ms, time_added_ms=time_now_ms,
) )
await self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
)
keys.update(response_keys)
return keys

View File

@ -433,7 +433,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
async def get_json(destination, path, **kwargs): async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME) self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1") self.assertEqual(path, "/_matrix/key/v2/server")
return response return response
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
@ -469,18 +469,6 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertEqual(keys, {}) self.assertEqual(keys, {})
def test_keyid_containing_forward_slash(self) -> None:
"""We should url-encode any url unsafe chars in key ids.
Detects https://github.com/matrix-org/synapse/issues/14488.
"""
fetcher = ServerKeyFetcher(self.hs)
self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0))
self.http_client.get_json.assert_called_once()
args, kwargs = self.http_client.get_json.call_args
self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato")
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import urllib.parse
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
@ -65,9 +64,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.assertTrue(ignore_backoff) self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name) self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version) key_id = "%s:%s" % (signing_key.alg, signing_key.version)
self.assertEqual( self.assertEqual(path, "/_matrix/key/v2/server")
path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
)
response = { response = {
"server_name": server_name, "server_name": server_name,