Support IPv6-only SMTP servers (#16155)
Use Twisted HostnameEndpoint to connect to SMTP servers (instead of connectTCP/connectSSL) which properly supports IPv6-only servers.
This commit is contained in:
parent
2d72367367
commit
63b51ef3fb
|
@ -0,0 +1 @@
|
|||
Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch).
|
|
@ -23,9 +23,11 @@ from pkg_resources import parse_version
|
|||
|
||||
import twisted
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IOpenSSLContextFactory
|
||||
from twisted.internet.endpoints import HostnameEndpoint
|
||||
from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory
|
||||
from twisted.internet.ssl import optionsForClientTLS
|
||||
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import ISynapseReactor
|
||||
|
@ -97,6 +99,7 @@ async def _sendmail(
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
factory: IProtocolFactory
|
||||
if _is_old_twisted:
|
||||
# before twisted 21.2, we have to override the ESMTPSender protocol to disable
|
||||
# TLS
|
||||
|
@ -110,22 +113,13 @@ async def _sendmail(
|
|||
factory = build_sender_factory(hostname=smtphost if enable_tls else None)
|
||||
|
||||
if force_tls:
|
||||
reactor.connectSSL(
|
||||
smtphost,
|
||||
smtpport,
|
||||
factory,
|
||||
optionsForClientTLS(smtphost),
|
||||
timeout=30,
|
||||
bindAddress=None,
|
||||
)
|
||||
else:
|
||||
reactor.connectTCP(
|
||||
smtphost,
|
||||
smtpport,
|
||||
factory,
|
||||
timeout=30,
|
||||
bindAddress=None,
|
||||
)
|
||||
factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory)
|
||||
|
||||
endpoint = HostnameEndpoint(
|
||||
reactor, smtphost, smtpport, timeout=30, bindAddress=None
|
||||
)
|
||||
|
||||
await make_deferred_yieldable(endpoint.connect(factory))
|
||||
|
||||
await make_deferred_yieldable(d)
|
||||
|
||||
|
|
|
@ -13,19 +13,40 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Callable, List, Tuple
|
||||
from typing import Callable, List, Tuple, Type, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet._sslverify import ClientTLSOptions
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.defer import ensureDeferred
|
||||
from twisted.internet.interfaces import IProtocolFactory
|
||||
from twisted.internet.ssl import ContextFactory
|
||||
from twisted.mail import interfaces, smtp
|
||||
|
||||
from tests.server import FakeTransport
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
def TestingESMTPTLSClientFactory(
|
||||
contextFactory: ContextFactory,
|
||||
_connectWrapped: bool,
|
||||
wrappedProtocol: IProtocolFactory,
|
||||
) -> IProtocolFactory:
|
||||
"""We use this to pass through in testing without using TLS, but
|
||||
saving the context information to check that it would have happened.
|
||||
|
||||
Note that this is what the MemoryReactor does on connectSSL.
|
||||
It only saves the contextFactory, but starts the connection with the
|
||||
underlying Factory.
|
||||
See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
|
||||
|
||||
wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
|
||||
return wrappedProtocol
|
||||
|
||||
|
||||
@implementer(interfaces.IMessageDelivery)
|
||||
class _DummyMessageDelivery:
|
||||
def __init__(self) -> None:
|
||||
|
@ -75,7 +96,13 @@ class _DummyMessage:
|
|||
pass
|
||||
|
||||
|
||||
class SendEmailHandlerTestCase(HomeserverTestCase):
|
||||
class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
|
||||
ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.reactor.lookups["localhost"] = "127.0.0.1"
|
||||
|
||||
def test_send_email(self) -> None:
|
||||
"""Happy-path test that we can send email to a non-TLS server."""
|
||||
h = self.hs.get_send_email_handler()
|
||||
|
@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
|
||||
0
|
||||
]
|
||||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(host, self.reactor.lookups["localhost"])
|
||||
self.assertEqual(port, 25)
|
||||
|
||||
# wire it up to an SMTP server
|
||||
|
@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
FakeTransport(
|
||||
client_protocol,
|
||||
self.reactor,
|
||||
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
|
||||
peer_address=self.ip_class(
|
||||
"TCP", self.reactor.lookups["localhost"], 1234
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
self.assertEqual(str(user), "foo@bar.com")
|
||||
self.assertIn(b"Subject: test subject", msg)
|
||||
|
||||
@patch(
|
||||
"synapse.handlers.send_email.TLSMemoryBIOFactory",
|
||||
TestingESMTPTLSClientFactory,
|
||||
)
|
||||
@override_config(
|
||||
{
|
||||
"email": {
|
||||
|
@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
)
|
||||
)
|
||||
# there should be an attempt to connect to localhost:465
|
||||
self.assertEqual(len(self.reactor.sslClients), 1)
|
||||
self.assertEqual(len(self.reactor.tcpClients), 1)
|
||||
(
|
||||
host,
|
||||
port,
|
||||
client_factory,
|
||||
contextFactory,
|
||||
_timeout,
|
||||
_bindAddress,
|
||||
) = self.reactor.sslClients[0]
|
||||
self.assertEqual(host, "localhost")
|
||||
) = self.reactor.tcpClients[0]
|
||||
self.assertEqual(host, self.reactor.lookups["localhost"])
|
||||
self.assertEqual(port, 465)
|
||||
# We need to make sure that TLS is happenning
|
||||
self.assertIsInstance(
|
||||
client_factory._wrappedFactory._testingContextFactory,
|
||||
ClientTLSOptions,
|
||||
)
|
||||
# And since we use endpoints, they go through reactor.connectTCP
|
||||
# which works differently to connectSSL on the testing reactor
|
||||
|
||||
# wire it up to an SMTP server
|
||||
message_delivery = _DummyMessageDelivery()
|
||||
|
@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
FakeTransport(
|
||||
client_protocol,
|
||||
self.reactor,
|
||||
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
|
||||
peer_address=self.ip_class(
|
||||
"TCP", self.reactor.lookups["localhost"], 1234
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
|
|||
user, msg = message_delivery.messages.pop()
|
||||
self.assertEqual(str(user), "foo@bar.com")
|
||||
self.assertIn(b"Subject: test subject", msg)
|
||||
|
||||
|
||||
class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
|
||||
ip_class = IPv6Address
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.reactor.lookups["localhost"] = "::1"
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import hashlib
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
@ -45,7 +46,7 @@ import attr
|
|||
from typing_extensions import ParamSpec
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import address, threads, udp
|
||||
from twisted.internet import address, tcp, threads, udp
|
||||
from twisted.internet._resolver import SimpleResolverComplexifier
|
||||
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
|
||||
from twisted.internet.error import DNSLookupError
|
||||
|
@ -567,6 +568,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
|||
conn = super().connectTCP(
|
||||
host, port, factory, timeout=timeout, bindAddress=None
|
||||
)
|
||||
if self.lookups and host in self.lookups:
|
||||
validate_connector(conn, self.lookups[host])
|
||||
|
||||
callback = self._tcp_callbacks.get((host, port))
|
||||
if callback:
|
||||
|
@ -599,6 +602,55 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
|
|||
super().advance(0)
|
||||
|
||||
|
||||
def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
|
||||
"""Try to validate the obtained connector as it would happen when
|
||||
synapse is running and the conection will be established.
|
||||
|
||||
This method will raise a useful exception when necessary, else it will
|
||||
just do nothing.
|
||||
|
||||
This is in order to help catch quirks related to reactor.connectTCP,
|
||||
since when called directly, the connector's destination will be of type
|
||||
IPv4Address, with the hostname as the literal host that was given (which
|
||||
could be an IPv6-only host or an IPv6 literal).
|
||||
|
||||
But when called from reactor.connectTCP *through* e.g. an Endpoint, the
|
||||
connector's destination will contain the specific IP address with the
|
||||
correct network stack class.
|
||||
|
||||
Note that testing code paths that use connectTCP directly should not be
|
||||
affected by this check, unless they specifically add a test with a
|
||||
matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
|
||||
type ThreadedMemoryReactorClock.
|
||||
For an example of implementing such tests, see test/handlers/send_email.py.
|
||||
"""
|
||||
destination = connector.getDestination()
|
||||
|
||||
# We use address.IPv{4,6}Address to check what the reactor thinks it is
|
||||
# is sending but check for validity with ipaddress.IPv{4,6}Address
|
||||
# because they fail with IPs on the wrong network stack.
|
||||
cls_mapping = {
|
||||
address.IPv4Address: ipaddress.IPv4Address,
|
||||
address.IPv6Address: ipaddress.IPv6Address,
|
||||
}
|
||||
|
||||
cls = cls_mapping.get(destination.__class__)
|
||||
|
||||
if cls is not None:
|
||||
try:
|
||||
cls(expected_ip)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Invalid IP type and resolution for %s. Expected %s to be %s"
|
||||
% (destination, expected_ip, cls.__name__)
|
||||
) from exc
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown address type %s for %s"
|
||||
% (destination.__class__.__name__, destination)
|
||||
)
|
||||
|
||||
|
||||
class ThreadPool:
|
||||
"""
|
||||
Threadless thread pool.
|
||||
|
|
|
@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase):
|
|||
servlets: List of servlet registration function.
|
||||
user_id (str): The user ID to assume if auth is hijacked.
|
||||
hijack_auth: Whether to hijack auth to return the user specified
|
||||
in user_id.
|
||||
in user_id.
|
||||
"""
|
||||
|
||||
hijack_auth: ClassVar[bool] = True
|
||||
|
|
Loading…
Reference in New Issue