Properly typecheck types.http (#14988)

* Tweak http types in Synapse

AFACIS these are correct, and they make mypy happier on tests.http.

* Type hints for test_proxyagent

* type hints for test_srv_resolver

* test_matrix_federation_agent

* tests.http.server._base

* tests.http.__init__

* tests.http.test_additional_resource

* tests.http.test_client

* tests.http.test_endpoint

* tests.http.test_matrixfederationclient

* tests.http.test_servlet

* tests.http.test_simple_client

* tests.http.test_site

* One fixup in tests.server

* Untyped defs

* Changelog

* Fixup syntax for Python 3.7

* Fix olddeps syntax

* Use a twisted IPv4 addr for dummy_address

* Fix typo, thanks Sean

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>

* Remove redundant `Optional`

---------

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
This commit is contained in:
David Robertson 2023-02-07 00:20:04 +00:00 committed by GitHub
parent 5fdc12f482
commit d0fed7a37b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 298 additions and 191 deletions

1
changelog.d/14988.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -32,9 +32,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/schema/
|tests/http/federation/test_matrix_federation_agent.py
|tests/http/federation/test_srv_resolver.py
|tests/http/test_proxyagent.py
|tests/module_api/test_api.py
|tests/rest/media/v1/test_media_storage.py
|tests/server.py
@ -92,6 +89,9 @@ disallow_untyped_defs = True
[mypy-tests.handlers.*]
disallow_untyped_defs = True
[mypy-tests.http.*]
disallow_untyped_defs = True
[mypy-tests.logging.*]
disallow_untyped_defs = True

View File

@ -44,6 +44,7 @@ from twisted.internet.interfaces import (
IAddress,
IDelayedCall,
IHostResolution,
IReactorCore,
IReactorPluggableNameResolver,
IReactorTime,
IResolutionReceiver,
@ -226,7 +227,9 @@ class _IPBlacklistingResolver:
return recv
@implementer(ISynapseReactor)
# ISynapseReactor implies IReactorCore, but explicitly marking it this as an implementer
# of IReactorCore seems to keep mypy-zope happier.
@implementer(IReactorCore, ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP

View File

@ -38,7 +38,6 @@ from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
from synapse.http import redact_uri
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@ -84,7 +83,7 @@ class ProxyAgent(_AgentBase):
def __init__(
self,
reactor: IReactorCore,
proxy_reactor: Optional[ISynapseReactor] = None,
proxy_reactor: Optional[IReactorCore] = None,
contextFactory: Optional[IPolicyForHTTPS] = None,
connectTimeout: Optional[float] = None,
bindAddress: Optional[bytes] = None,

View File

@ -19,13 +19,15 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.ssl import Certificate, trustRootFromCertificates
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
def get_test_https_policy():
def get_test_https_policy() -> BrowserLikePolicyForHTTPS:
"""Get a test IPolicyForHTTPS which trusts the test CA cert
Returns:
@ -39,7 +41,7 @@ def get_test_https_policy():
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
def get_test_ca_cert_file() -> str:
"""Get the path to the test CA cert
The keypair is generated with:
@ -51,7 +53,7 @@ def get_test_ca_cert_file():
return os.path.join(os.path.dirname(__file__), "ca.crt")
def get_test_key_file():
def get_test_key_file() -> str:
"""get the path to the test key
The key file is made with:
@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
def __init__(self, sanlist):
def __init__(self, sanlist: List[bytes]):
"""
Args:
sanlist: list[bytes]: a list of subjectAltName values for the cert
sanlist: a list of subjectAltName values for the cert
"""
self._cert_file = create_test_cert_file(sanlist)
def serverConnectionForTLS(self, tlsProtocol):
def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection:
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_certificate_file(self._cert_file)
ctx.use_privatekey_file(get_test_key_file())
return Connection(ctx, None)
# A dummy address, useful for tests that use FakeTransport and don't care about where
# packets are going to/coming from.
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)

View File

@ -14,7 +14,7 @@
import base64
import logging
import os
from typing import Iterable, Optional
from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
from unittest.mock import Mock, patch
import treq
@ -24,14 +24,19 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
from twisted.internet.interfaces import IProtocolFactory
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import _WrappingProtocol
from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IProtocolFactory,
)
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
from twisted.web.http import HTTPChannel, Request
from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
from twisted.web.iweb import IPolicyForHTTPS, IResponse
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import FederationPolicyForHTTPS
@ -42,11 +47,21 @@ from synapse.http.federation.well_known_resolver import (
WellKnownResolver,
_cache_period_from_headers,
)
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
LoggingContextOrSentinel,
current_context,
)
from synapse.types import ISynapseReactor
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_ca_cert_file,
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.utils import default_config
@ -54,15 +69,17 @@ logger = logging.getLogger(__name__)
# Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(result):
async def resolve_service(_):
def generate_resolve_service(
result: List[Server],
) -> Callable[[Any], Awaitable[List[Server]]]:
async def resolve_service(_: Any) -> List[Server]:
return result
return resolve_service
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.tls_factory = FederationPolicyForHTTPS(config)
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache(
"test_cache", timer=self.reactor.seconds
)
self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache(
"test_cache", timer=self.reactor.seconds
)
self.well_known_resolver = WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=self.tls_factory),
@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
self,
client_factory: IProtocolFactory,
ssl: bool = True,
expected_sni: bytes = None,
tls_sanlist: Optional[Iterable[bytes]] = None,
expected_sni: Optional[bytes] = None,
tls_sanlist: Optional[List[bytes]] = None,
) -> HTTPChannel:
"""Builds a test server, and completes the outgoing client connection
Args:
@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_protocol = server_factory.buildProtocol(None)
server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
@ -125,7 +146,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol = client_factory.buildProtocol(dummy_address)
assert isinstance(client_protocol, _WrappingProtocol)
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@ -136,6 +158,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
)
if ssl:
assert isinstance(server_protocol, TLSMemoryBIOProtocol)
# fish the test server back out of the server-side TLS protocol.
http_protocol = server_protocol.wrappedProtocol
# grab a hold of the TLS connection, in case it gets torn down
@ -144,6 +167,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
http_protocol = server_protocol
tls_connection = None
assert isinstance(http_protocol, HTTPChannel)
# give the reactor a pump to get the TLS juices flowing (if needed)
self.reactor.advance(0)
@ -159,12 +183,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri: bytes):
def _make_get_request(
self, uri: bytes
) -> Generator["Deferred[object]", object, IResponse]:
"""
Sends a simple GET request via the agent, and checks its logcontext management
"""
with LoggingContext("one") as context:
fetch_d = self.agent.request(b"GET", uri)
fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri)
# Nothing happened yet
self.assertNoResult(fetch_d)
@ -172,8 +198,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# should have reset logcontext to the sentinel
_check_logcontext(SENTINEL_CONTEXT)
fetch_res: IResponse
try:
fetch_res = yield fetch_d
fetch_res = yield fetch_d # type: ignore[misc, assignment]
return fetch_res
except Exception as e:
logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
@ -216,7 +243,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
request: Request,
content: bytes,
headers: Optional[dict] = None,
):
) -> None:
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@ -237,16 +264,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
because it is created too early during setUp
"""
return MatrixFederationAgent(
reactor=self.reactor,
reactor=cast(ISynapseReactor, self.reactor),
tls_client_options_factory=self.tls_factory,
user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_whitelist=IPSet(),
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
def test_get(self):
def test_get(self) -> None:
"""happy-path test of a GET request with an explicit port"""
self._do_get()
@ -254,11 +281,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
os.environ,
{"https_proxy": "proxy.com", "no_proxy": "testserv"},
)
def test_get_bypass_proxy(self):
def test_get_bypass_proxy(self) -> None:
"""test of a GET request with an explicit port and bypass proxy"""
self._do_get()
def _do_get(self):
def _do_get(self) -> None:
"""test of a GET request with an explicit port"""
self.agent = self._make_agent()
@ -318,7 +345,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
@patch.dict(
os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
)
def test_get_via_http_proxy(self):
def test_get_via_http_proxy(self) -> None:
"""test for federation request through a http proxy"""
self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
@ -326,7 +353,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
os.environ,
{"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
)
def test_get_via_http_proxy_with_auth(self):
def test_get_via_http_proxy_with_auth(self) -> None:
"""test for federation request through a http proxy with authentication"""
self._do_get_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
@ -335,7 +362,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
@patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
)
def test_get_via_https_proxy(self):
def test_get_via_https_proxy(self) -> None:
"""test for federation request through a https proxy"""
self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
@ -343,7 +370,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
os.environ,
{"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
)
def test_get_via_https_proxy_with_auth(self):
def test_get_via_https_proxy_with_auth(self) -> None:
"""test for federation request through a https proxy with authentication"""
self._do_get_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
@ -353,7 +380,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
):
) -> None:
"""Send a https federation request via an agent and check that it is correctly
received at the proxy and client. The proxy can use either http or https.
Args:
@ -418,10 +445,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(None)
).buildProtocol(dummy_address)
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
assert proxy_server_transport is not None
server_ssl_protocol.makeConnection(proxy_server_transport)
# ... and replace the protocol on the proxy's transport with the
@ -451,6 +480,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# now there should be a pending request
http_server = server_ssl_protocol.wrappedProtocol
assert isinstance(http_server, HTTPChannel)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
@ -491,7 +521,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
def test_get_ip_address(self):
def test_get_ip_address(self) -> None:
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
@ -526,7 +556,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_ipv6_address(self):
def test_get_ipv6_address(self) -> None:
"""
Test the behaviour when the server name contains an explicit IPv6 address
(with no port)
@ -562,7 +592,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_ipv6_address_with_port(self):
def test_get_ipv6_address_with_port(self) -> None:
"""
Test the behaviour when the server name contains an explicit IPv6 address
(with explicit port)
@ -598,7 +628,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_hostname_bad_cert(self):
def test_get_hostname_bad_cert(self) -> None:
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
@ -651,7 +681,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
failure_reason = e.value.reasons[0]
self.assertIsInstance(failure_reason.value, VerificationError)
def test_get_ip_address_bad_cert(self):
def test_get_ip_address_bad_cert(self) -> None:
"""
Test the behaviour when the server name contains an explicit IP, but
the server cert doesn't cover it
@ -684,7 +714,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
failure_reason = e.value.reasons[0]
self.assertIsInstance(failure_reason.value, VerificationError)
def test_get_no_srv_no_well_known(self):
def test_get_no_srv_no_well_known(self) -> None:
"""
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
@ -740,7 +770,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_well_known(self):
def test_get_well_known(self) -> None:
"""Test the behaviour when the .well-known delegates elsewhere"""
self.agent = self._make_agent()
@ -802,7 +832,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
def test_get_well_known_redirect(self):
def test_get_well_known_redirect(self) -> None:
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
@ -892,7 +922,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
def test_get_invalid_well_known(self):
def test_get_invalid_well_known(self) -> None:
"""
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
@ -945,7 +975,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_well_known_unsigned_cert(self):
def test_get_well_known_unsigned_cert(self) -> None:
"""Test the behaviour when the .well-known server presents a cert
not signed by a CA
"""
@ -969,7 +999,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
cast(ISynapseReactor, self.reactor),
Agent(self.reactor, contextFactory=tls_factory),
b"test-agent",
well_known_cache=self.well_known_cache,
@ -999,7 +1029,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
b"_matrix._tcp.testserv"
)
def test_get_hostname_srv(self):
def test_get_hostname_srv(self) -> None:
"""
Test the behaviour when there is a single SRV record
"""
@ -1041,7 +1071,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_well_known_srv(self):
def test_get_well_known_srv(self) -> None:
"""Test the behaviour when the .well-known redirects to a place where there
is a SRV.
"""
@ -1101,7 +1131,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_idna_servername(self):
def test_idna_servername(self) -> None:
"""test the behaviour when the server name has idna chars in"""
self.agent = self._make_agent()
@ -1163,7 +1193,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_idna_srv_target(self):
def test_idna_srv_target(self) -> None:
"""test the behaviour when the target of a SRV record has idna chars"""
self.agent = self._make_agent()
@ -1206,7 +1236,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_well_known_cache(self):
def test_well_known_cache(self) -> None:
self.reactor.lookups["testserv"] = "1.2.3.4"
fetch_d = defer.ensureDeferred(
@ -1262,7 +1292,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"other-server")
def test_well_known_cache_with_temp_failure(self):
def test_well_known_cache_with_temp_failure(self) -> None:
"""Test that we refetch well-known before the cache expires, and that
it ignores transient errors.
"""
@ -1341,7 +1371,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
def test_well_known_too_large(self):
def test_well_known_too_large(self) -> None:
"""A well-known query that returns a result which is too large should be rejected."""
self.reactor.lookups["testserv"] = "1.2.3.4"
@ -1367,7 +1397,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
r = self.successResultOf(fetch_d)
self.assertIsNone(r.delegated_server)
def test_srv_fallbacks(self):
def test_srv_fallbacks(self) -> None:
"""Test that other SRV results are tried if the first one fails."""
self.agent = self._make_agent()
@ -1427,7 +1457,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self):
def test_cache_control(self) -> None:
# uppercase
self.assertEqual(
_cache_period_from_headers(
@ -1464,7 +1494,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
0,
)
def test_expires(self):
def test_expires(self) -> None:
self.assertEqual(
_cache_period_from_headers(
Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}),
@ -1491,14 +1521,14 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0)
def _check_logcontext(context):
def _check_logcontext(context: LoggingContextOrSentinel) -> None:
current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Iterable[bytes] = None
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> IProtocolFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
@ -1537,7 +1567,7 @@ def _get_test_protocol_factory() -> IProtocolFactory:
return server_factory
def _log_request(request: str):
def _log_request(request: str) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info(f"Completed request {request}")
@ -1547,6 +1577,8 @@ class TrustingTLSPolicyForHTTPS:
"""An IPolicyForHTTPS which checks that the certificate belongs to the
right server, but doesn't check the certificate chain."""
def creatorForNetloc(self, hostname, port):
def creatorForNetloc(
self, hostname: bytes, port: int
) -> IOpenSSLClientConnectionCreator:
certificateOptions = OpenSSLCertificateOptions()
return ClientTLSOptions(hostname, certificateOptions.getContext())

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Generator, List, Tuple, cast
from unittest.mock import Mock
from twisted.internet import defer
@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.logging.context import LoggingContext, current_context
from tests import unittest
@ -28,7 +28,7 @@ from tests.utils import MockClock
class SrvResolverTestCase(unittest.TestCase):
def test_resolve(self):
def test_resolve(self) -> None:
dns_client_mock = Mock()
service_name = b"test_service.example.com"
@ -38,18 +38,19 @@ class SrvResolverTestCase(unittest.TestCase):
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
)
result_deferred = Deferred()
result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
dns_client_mock.lookupService.return_value = result_deferred
cache = {}
cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@defer.inlineCallbacks
def do_lookup():
def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
result = yield defer.ensureDeferred(resolve_d)
result: List[Server]
result = yield defer.ensureDeferred(resolve_d) # type: ignore[assignment]
# should have restored our context
self.assertIs(current_context(), ctx)
@ -70,7 +71,9 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertEqual(servers[0].host, host_name)
@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
def test_from_cache_expired_and_dns_fail(
self,
) -> Generator["Deferred[object]", object, None]:
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
@ -81,10 +84,13 @@ class SrvResolverTestCase(unittest.TestCase):
entry.priority = 0
entry.weight = 0
cache = {service_name: [entry]}
cache = {service_name: [cast(Server, entry)]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
servers: List[Server]
servers = yield defer.ensureDeferred(
resolver.resolve_service(service_name)
) # type: ignore[assignment]
dns_client_mock.lookupService.assert_called_once_with(service_name)
@ -92,7 +98,7 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertEqual(servers, cache[service_name])
@defer.inlineCallbacks
def test_from_cache(self):
def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
clock = MockClock()
dns_client_mock = Mock(spec_set=["lookupService"])
@ -105,12 +111,15 @@ class SrvResolverTestCase(unittest.TestCase):
entry.priority = 0
entry.weight = 0
cache = {service_name: [entry]}
cache = {service_name: [cast(Server, entry)]}
resolver = SrvResolver(
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
servers: List[Server]
servers = yield defer.ensureDeferred(
resolver.resolve_service(service_name)
) # type: ignore[assignment]
self.assertFalse(dns_client_mock.lookupService.called)
@ -118,45 +127,48 @@ class SrvResolverTestCase(unittest.TestCase):
self.assertEqual(servers, cache[service_name])
@defer.inlineCallbacks
def test_empty_cache(self):
def test_empty_cache(self) -> Generator["Deferred[object]", object, None]:
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = b"test_service.example.com"
cache = {}
cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
def test_name_error(self):
def test_name_error(self) -> Generator["Deferred[object]", object, None]:
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
service_name = b"test_service.example.com"
cache = {}
cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
servers: List[Server]
servers = yield defer.ensureDeferred(
resolver.resolve_service(service_name)
) # type: ignore[assignment]
self.assertEqual(len(servers), 0)
self.assertEqual(len(cache), 0)
def test_disabled_service(self):
def test_disabled_service(self) -> None:
"""
test the behaviour when there is a single record which is ".".
"""
service_name = b"test_service.example.com"
lookup_deferred = Deferred()
lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
# Old versions of Twisted don't have an ensureDeferred in failureResultOf.
@ -173,16 +185,16 @@ class SrvResolverTestCase(unittest.TestCase):
self.failureResultOf(resolve_d, ConnectError)
def test_non_srv_answer(self):
def test_non_srv_answer(self) -> None:
"""
test the behaviour when the dns server gives us a spurious non-SRV response
"""
service_name = b"test_service.example.com"
lookup_deferred = Deferred()
lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
cache: Dict[bytes, List[Server]] = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
# Old versions of Twisted don't have an ensureDeferred in successResultOf.

View File

@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
return method_name
def _hash_stack(stack: List[inspect.FrameInfo]):
def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]:
"""Turns a stack into a hashable value that can be put into a set."""
return tuple(_format_stack_frame(frame) for frame in stack)

View File

@ -11,28 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from twisted.web.server import Request
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from tests.server import FakeSite, make_request
from tests.unittest import HomeserverTestCase
class _AsyncTestCustomEndpoint:
def __init__(self, config, module_api):
def __init__(self, config: JsonDict, module_api: Any) -> None:
pass
async def handle_request(self, request):
async def handle_request(self, request: Request) -> None:
assert isinstance(request, SynapseRequest)
respond_with_json(request, 200, {"some_key": "some_value_async"})
class _SyncTestCustomEndpoint:
def __init__(self, config, module_api):
def __init__(self, config: JsonDict, module_api: Any) -> None:
pass
async def handle_request(self, request):
async def handle_request(self, request: Request) -> None:
assert isinstance(request, SynapseRequest)
respond_with_json(request, 200, {"some_key": "some_value_sync"})
@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase):
and async handlers.
"""
def test_async(self):
def test_async(self) -> None:
handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
def test_sync(self) -> None:
handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)

View File

@ -13,10 +13,12 @@
# limitations under the License.
from io import BytesIO
from typing import Tuple, Union
from unittest.mock import Mock
from netaddr import IPSet
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
@ -28,6 +30,7 @@ from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
_DiscardBodyWithMaxSizeProtocol,
read_body_with_max_size,
)
@ -36,7 +39,9 @@ from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
def _build_response(self, length=UNKNOWN_LENGTH):
def _build_response(
self, length: Union[int, str] = UNKNOWN_LENGTH
) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase):
return result, deferred, protocol
def _assert_error(self, deferred, protocol):
def _assert_error(
self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
) -> None:
"""Ensure that the expected error is received."""
self.assertIsInstance(deferred.result, Failure)
assert isinstance(deferred.result, Failure)
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
protocol.transport.abortConnection.assert_called_once()
assert protocol.transport is not None
# type-ignore: presumably abortConnection has been replaced with a Mock.
protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
def _cleanup_error(self, deferred):
def _cleanup_error(self, deferred: "Deferred[int]") -> None:
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
def errback(f):
def errback(f: Failure) -> None:
called[0] = True
deferred.addErrback(errback)
self.assertTrue(called[0])
def test_no_error(self):
def test_no_error(self) -> None:
"""A response that is NOT too large."""
result, deferred, protocol = self._build_response()
@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"12345")
self.assertEqual(deferred.result, 5)
def test_too_large(self):
def test_too_large(self) -> None:
"""A response which is too large raises an exception."""
result, deferred, protocol = self._build_response()
@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
def test_multiple_packets(self):
def test_multiple_packets(self) -> None:
"""Data should be accumulated through mutliple packets."""
result, deferred, protocol = self._build_response()
@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"1234")
self.assertEqual(deferred.result, 4)
def test_additional_data(self):
def test_additional_data(self) -> None:
"""A connection can receive data after being closed."""
result, deferred, protocol = self._build_response()
@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
def test_content_length(self):
def test_content_length(self) -> None:
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
result, deferred, protocol = self._build_response(length=10)
@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
class BlacklistingAgentTest(TestCase):
def setUp(self):
def setUp(self) -> None:
self.reactor, self.clock = get_clock()
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase):
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
def test_reactor(self):
def test_reactor(self) -> None:
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase):
response = self.successResultOf(d)
self.assertEqual(response.code, 200)
def test_agent(self):
def test_agent(self) -> None:
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),

View File

@ -17,7 +17,7 @@ from tests import unittest
class ServerNameTestCase(unittest.TestCase):
def test_parse_server_name(self):
def test_parse_server_name(self) -> None:
test_data = {
"localhost": ("localhost", None),
"my-example.com:1234": ("my-example.com", 1234),
@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase):
for i, o in test_data.items():
self.assertEqual(parse_server_name(i), o)
def test_validate_bad_server_names(self):
def test_validate_bad_server_names(self) -> None:
test_data = [
"", # empty
"localhost:http", # non-numeric port

View File

@ -11,16 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Generator
from unittest.mock import Mock
from netaddr import IPSet
from parameterized import parameterized
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
from twisted.internet.defer import Deferred, TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.test.proto_helpers import StringTransport
from twisted.test.proto_helpers import MemoryReactor, StringTransport
from twisted.web.client import ResponseNeverReceived
from twisted.web.http import HTTPChannel
@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
from synapse.logging.context import (
SENTINEL_CONTEXT,
LoggingContext,
LoggingContextOrSentinel,
current_context,
)
from synapse.server import HomeServer
from synapse.util import Clock
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
def check_logcontext(context):
def check_logcontext(context: LoggingContextOrSentinel) -> None:
current = current_context()
if current is not context:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs
def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.cl = MatrixFederationHttpClient(self.hs, None)
self.reactor.lookups["testserv"] = "1.2.3.4"
def test_client_get(self):
def test_client_get(self) -> None:
"""
happy-path test of a GET request
"""
@defer.inlineCallbacks
def do_request():
def do_request() -> Generator["Deferred[object]", object, object]:
with LoggingContext("one") as context:
fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar")
@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase):
# check the response is as expected
self.assertEqual(res, {"a": 1})
def test_dns_error(self):
def test_dns_error(self) -> None:
"""
If the DNS lookup returns an error, it will bubble up.
"""
@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
def test_client_connection_refused(self) -> None:
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
)
@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIs(f.value.inner_exception, e)
def test_client_never_connect(self):
def test_client_never_connect(self) -> None:
"""
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase):
f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
)
def test_client_connect_no_response(self):
def test_client_connect_no_response(self) -> None:
"""
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
def test_client_ip_range_blacklist(self):
def test_client_ip_range_blacklist(self) -> None:
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Set up the ip_range blacklist
@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase):
f = self.failureResultOf(d, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
def test_client_gets_headers(self):
def test_client_gets_headers(self) -> None:
"""
Once the client gets the headers, _request returns successfully.
"""
@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(r.code, 200)
@parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
def test_timeout_reading_body(self, method_name: str):
def test_timeout_reading_body(self, method_name: str) -> None:
"""
If the HTTP request is connected, but gets no response before being
timed out, it'll give a RequestSendFailed with can_retry.
@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertTrue(f.value.can_retry)
self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
def test_client_requires_trailing_slashes(self):
def test_client_requires_trailing_slashes(self) -> None:
"""
If a connection is made to a client but the client rejects it due to
requiring a trailing slash. We need to retry the request with a
@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r, {})
def test_client_does_not_retry_on_400_plus(self):
def test_client_does_not_retry_on_400_plus(self) -> None:
"""
Another test for trailing slashes but now test that we don't retry on
trailing slashes on a non-400/M_UNRECOGNIZED response.
@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase):
# We should get a 404 failure response
self.failureResultOf(d)
def test_client_sends_body(self):
def test_client_sends_body(self) -> None:
defer.ensureDeferred(
self.cl.post_json(
"testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase):
content = request.content.read()
self.assertEqual(content, b'{"a":"b"}')
def test_closes_connection(self):
def test_closes_connection(self) -> None:
"""Check that the client closes unused HTTP connections"""
d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertTrue(conn.disconnecting)
@parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
def test_json_error(self, return_value):
def test_json_error(self, return_value: bytes) -> None:
"""
Test what happens if invalid JSON is returned from the remote endpoint.
"""
@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase):
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
def test_too_big(self):
def test_too_big(self) -> None:
"""
Test what happens if a huge response is returned from the remote endpoint.
"""

View File

@ -14,7 +14,7 @@
import base64
import logging
import os
from typing import Iterable, Optional
from typing import List, Optional
from unittest.mock import patch
import treq
@ -22,7 +22,11 @@ from netaddr import IPSet
from parameterized import parameterized
from twisted.internet import interfaces # noqa: F401
from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint
from twisted.internet.endpoints import (
HostnameEndpoint,
_WrapperEndpoint,
_WrappingProtocol,
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
@ -32,7 +36,11 @@ from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.connectproxyclient import ProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_https_policy,
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
@ -183,7 +191,7 @@ class ProxyParserTests(TestCase):
expected_hostname: bytes,
expected_port: int,
expected_credentials: Optional[bytes],
):
) -> None:
"""
Tests that a given proxy URL will be broken into the components.
Args:
@ -209,7 +217,7 @@ class ProxyParserTests(TestCase):
class MatrixFederationAgentTests(TestCase):
def setUp(self):
def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
def _make_connection(
@ -218,7 +226,7 @@ class MatrixFederationAgentTests(TestCase):
server_factory: IProtocolFactory,
ssl: bool = False,
expected_sni: Optional[bytes] = None,
tls_sanlist: Optional[Iterable[bytes]] = None,
tls_sanlist: Optional[List[bytes]] = None,
) -> IProtocol:
"""Builds a test server, and completes the outgoing client connection
@ -244,7 +252,8 @@ class MatrixFederationAgentTests(TestCase):
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_protocol = server_factory.buildProtocol(None)
server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
# now, tell the client protocol factory to build the client protocol,
# and wire the output of said protocol up to the server via
@ -252,7 +261,8 @@ class MatrixFederationAgentTests(TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol = client_factory.buildProtocol(dummy_address)
assert client_protocol is not None
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@ -263,6 +273,7 @@ class MatrixFederationAgentTests(TestCase):
)
if ssl:
assert isinstance(server_protocol, TLSMemoryBIOProtocol)
http_protocol = server_protocol.wrappedProtocol
tls_connection = server_protocol._tlsConnection
else:
@ -288,7 +299,7 @@ class MatrixFederationAgentTests(TestCase):
scheme: bytes,
hostname: bytes,
path: bytes,
):
) -> None:
"""Runs a test case for a direct connection not going through a proxy.
Args:
@ -319,6 +330,7 @@ class MatrixFederationAgentTests(TestCase):
ssl=is_https,
expected_sni=hostname if is_https else None,
)
assert isinstance(http_server, HTTPChannel)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@ -339,34 +351,34 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
def test_http_request(self):
def test_http_request(self) -> None:
agent = ProxyAgent(self.reactor)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
def test_https_request(self):
def test_https_request(self) -> None:
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
def test_http_request_use_proxy_empty_environment(self):
def test_http_request_use_proxy_empty_environment(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
def test_http_request_via_uppercase_no_proxy(self):
def test_http_request_via_uppercase_no_proxy(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(
os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
)
def test_http_request_via_no_proxy(self):
def test_http_request_via_no_proxy(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(
os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
)
def test_https_request_via_no_proxy(self):
def test_https_request_via_no_proxy(self) -> None:
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
@ -375,12 +387,12 @@ class MatrixFederationAgentTests(TestCase):
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
def test_http_request_via_no_proxy_star(self):
def test_http_request_via_no_proxy_star(self) -> None:
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
def test_https_request_via_no_proxy_star(self):
def test_https_request_via_no_proxy_star(self) -> None:
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
@ -389,7 +401,7 @@ class MatrixFederationAgentTests(TestCase):
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
def test_http_request_via_proxy(self):
def test_http_request_via_proxy(self) -> None:
"""
Tests that requests can be made through a proxy.
"""
@ -401,7 +413,7 @@ class MatrixFederationAgentTests(TestCase):
os.environ,
{"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
)
def test_http_request_via_proxy_with_auth(self):
def test_http_request_via_proxy_with_auth(self) -> None:
"""
Tests that authenticated requests can be made through a proxy.
"""
@ -412,7 +424,7 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
)
def test_http_request_via_https_proxy(self):
def test_http_request_via_https_proxy(self) -> None:
self._do_http_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=None
)
@ -424,13 +436,13 @@ class MatrixFederationAgentTests(TestCase):
"no_proxy": "unused.com",
},
)
def test_http_request_via_https_proxy_with_auth(self):
def test_http_request_via_https_proxy_with_auth(self) -> None:
self._do_http_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
)
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
def test_https_request_via_proxy(self) -> None:
"""Tests that TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=None
@ -440,7 +452,7 @@ class MatrixFederationAgentTests(TestCase):
os.environ,
{"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
)
def test_https_request_via_proxy_with_auth(self):
def test_https_request_via_proxy_with_auth(self) -> None:
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
@ -449,7 +461,7 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
)
def test_https_request_via_https_proxy(self):
def test_https_request_via_https_proxy(self) -> None:
"""Tests that TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=None
@ -459,7 +471,7 @@ class MatrixFederationAgentTests(TestCase):
os.environ,
{"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
)
def test_https_request_via_https_proxy_with_auth(self):
def test_https_request_via_https_proxy_with_auth(self) -> None:
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
self._do_https_request_via_proxy(
expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
@ -469,7 +481,7 @@ class MatrixFederationAgentTests(TestCase):
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
):
) -> None:
"""Send a http request via an agent and check that it is correctly received at
the proxy. The proxy can use either http or https.
Args:
@ -501,6 +513,7 @@ class MatrixFederationAgentTests(TestCase):
tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
expected_sni=b"proxy.com" if expect_proxy_ssl else None,
)
assert isinstance(http_server, HTTPChannel)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@ -542,7 +555,7 @@ class MatrixFederationAgentTests(TestCase):
self,
expect_proxy_ssl: bool = False,
expected_auth_credentials: Optional[bytes] = None,
):
) -> None:
"""Send a https request via an agent and check that it is correctly received at
the proxy and client. The proxy can use either http or https.
Args:
@ -606,10 +619,12 @@ class MatrixFederationAgentTests(TestCase):
# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(None)
).buildProtocol(dummy_address)
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
assert proxy_server_transport is not None
server_ssl_protocol.makeConnection(proxy_server_transport)
# ... and replace the protocol on the proxy's transport with the
@ -644,6 +659,7 @@ class MatrixFederationAgentTests(TestCase):
# now there should be a pending request
http_server = server_ssl_protocol.wrappedProtocol
assert isinstance(http_server, HTTPChannel)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
@ -667,7 +683,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_http_request_via_proxy_with_blacklist(self):
def test_http_request_via_proxy_with_blacklist(self) -> None:
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
@ -691,6 +707,7 @@ class MatrixFederationAgentTests(TestCase):
http_server = self._make_connection(
client_factory, _get_test_protocol_factory()
)
assert isinstance(http_server, HTTPChannel)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@ -712,7 +729,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
def test_https_request_via_uppercase_proxy_with_blacklist(self):
def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
@ -737,11 +754,15 @@ class MatrixFederationAgentTests(TestCase):
proxy_server = self._make_connection(
client_factory, _get_test_protocol_factory()
)
assert isinstance(proxy_server, HTTPChannel)
# fish the transports back out so that we can do the old switcheroo
s2c_transport = proxy_server.transport
assert isinstance(s2c_transport, FakeTransport)
client_protocol = s2c_transport.other
assert isinstance(client_protocol, _WrappingProtocol)
c2s_transport = client_protocol.transport
assert isinstance(c2s_transport, FakeTransport)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@ -762,8 +783,10 @@ class MatrixFederationAgentTests(TestCase):
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
ssl_protocol = ssl_factory.buildProtocol(None)
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
http_server = ssl_protocol.wrappedProtocol
assert isinstance(http_server, HTTPChannel)
ssl_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, ssl_protocol)
@ -797,28 +820,28 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_proxy_with_no_scheme(self):
def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
def test_proxy_with_unsupported_scheme(self):
def test_proxy_with_unsupported_scheme(self) -> None:
with self.assertRaises(ValueError):
ProxyAgent(self.reactor, use_proxy=True)
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
def test_proxy_with_http_scheme(self):
def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
def test_proxy_with_https_scheme(self):
def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
self.assertEqual(
https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
)
@ -828,7 +851,7 @@ class MatrixFederationAgentTests(TestCase):
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Iterable[bytes] = None
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> IProtocolFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
@ -865,6 +888,6 @@ def _get_test_protocol_factory() -> IProtocolFactory:
return server_factory
def _log_request(request: str):
def _log_request(request: str) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info(f"Completed request {request}")

View File

@ -14,7 +14,7 @@
import json
from http import HTTPStatus
from io import BytesIO
from typing import Tuple
from typing import Tuple, Union
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
@ -33,7 +33,7 @@ from tests import unittest
from tests.http.server._base import test_disconnect
def make_request(content):
def make_request(content: Union[bytes, JsonDict]) -> Mock:
"""Make an object that acts enough like a request."""
request = Mock(spec=["method", "uri", "content"])
@ -47,7 +47,7 @@ def make_request(content):
class TestServletUtils(unittest.TestCase):
def test_parse_json_value(self):
def test_parse_json_value(self) -> None:
"""Basic tests for parse_json_value_from_request."""
# Test round-tripping.
obj = {"foo": 1}
@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase):
with self.assertRaises(SynapseError):
parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
def test_parse_json_object(self):
def test_parse_json_object(self) -> None:
"""Basic tests for parse_json_object_from_request."""
# Test empty.
result = parse_json_object_from_request(

View File

@ -17,22 +17,24 @@ from netaddr import IPSet
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import MemoryReactor
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class SimpleHttpClientTests(HomeserverTestCase):
def prepare(self, reactor, clock, hs: "HomeServer"):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None:
# Add a DNS entry for a test server
self.reactor.lookups["testserv"] = "1.2.3.4"
self.cl = hs.get_simple_http_client()
def test_dns_error(self):
def test_dns_error(self) -> None:
"""
If the DNS lookup returns an error, it will bubble up.
"""
@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
f = self.failureResultOf(d)
self.assertIsInstance(f.value, DNSLookupError)
def test_client_connection_refused(self):
def test_client_connection_refused(self) -> None:
d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
self.pump()
@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIs(f.value, e)
def test_client_never_connect(self):
def test_client_never_connect(self) -> None:
"""
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestTimedOutError)
def test_client_connect_no_response(self):
def test_client_connect_no_response(self) -> None:
"""
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestTimedOutError)
def test_client_ip_range_blacklist(self):
def test_client_ip_range_blacklist(self) -> None:
"""Ensure that Synapse does not try to connect to blacklisted IPs"""
# Add some DNS entries we'll blacklist

View File

@ -13,18 +13,20 @@
# limitations under the License.
from twisted.internet.address import IPv6Address
from twisted.test.proto_helpers import StringTransport
from twisted.test.proto_helpers import MemoryReactor, StringTransport
from synapse.app.homeserver import SynapseHomeServer
from synapse.server import HomeServer
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class SynapseRequestTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
def test_large_request(self):
def test_large_request(self) -> None:
"""overlarge HTTP requests should be rejected"""
self.hs.start_listening()

View File

@ -70,7 +70,7 @@ from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
from tests.utils import (
@ -401,7 +401,9 @@ def make_request(
return channel
@implementer(IReactorPluggableNameResolver)
# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
# marking this as an implementer of the latter seems to keep mypy-zope happier.
@implementer(IReactorPluggableNameResolver, ISynapseReactor)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.