Feature to allow setting of custom CORS response headers

This commit is contained in:
Hugh Nimmo-Smith 2023-09-20 22:04:31 +01:00
parent 7ec0a141b4
commit 5fa409b078
6 changed files with 93 additions and 28 deletions

View File

@ -464,6 +464,13 @@ See the docs [request log format](../administration/request_log.md).
* `additional_resources`: Only valid for an 'http' listener. A map of * `additional_resources`: Only valid for an 'http' listener. A map of
additional endpoints which should be loaded via dynamic modules. additional endpoints which should be loaded via dynamic modules.
* `cors_response_headers`: Only valid for an 'http' listener. A map of Cross-Origin Resource Sharing
headers to use in place of the default ones. You could choose to do this using a
[reverse-proxy](../../reverse_proxy.md) instead.
_Added in Synapse 1.94.0._
http://localhost:8008
Unix socket support (_Added in Synapse 1.89.0_): Unix socket support (_Added in Synapse 1.89.0_):
* `path`: A path and filename for a Unix socket. Make sure it is located in a * `path`: A path and filename for a Unix socket. Make sure it is located in a
directory with read and write permissions, and that it already exists (the directory directory with read and write permissions, and that it already exists (the directory

View File

@ -211,6 +211,7 @@ class HttpListenerConfig:
# If true, the listener will return CORS response headers compatible with MSC3886: # If true, the listener will return CORS response headers compatible with MSC3886:
# https://github.com/matrix-org/matrix-spec-proposals/pull/3886 # https://github.com/matrix-org/matrix-spec-proposals/pull/3886
experimental_cors_msc3886: bool = False experimental_cors_msc3886: bool = False
cors_response_headers: Optional[Dict[str, dict]] = None
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -989,6 +990,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
tag=listener.get("tag"), tag=listener.get("tag"),
request_id_header=listener.get("request_id_header"), request_id_header=listener.get("request_id_header"),
experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False),
cors_response_headers=listener.get("cors_response_headers"),
) )
if socket_path: if socket_path:

View File

@ -899,28 +899,8 @@ def set_cors_headers(request: "SynapseRequest") -> None:
Args: Args:
request: The http request to add CORS to. request: The http request to add CORS to.
""" """
request.setHeader(b"Access-Control-Allow-Origin", b"*") for k, v in request.cors_response_headers.items():
request.setHeader( request.setHeader(k, v)
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
)
if request.experimental_cors_msc3886:
request.setHeader(
b"Access-Control-Allow-Headers",
b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match",
)
request.setHeader(
b"Access-Control-Expose-Headers",
b"ETag, Location, X-Max-Bytes",
)
else:
request.setHeader(
b"Access-Control-Allow-Headers",
b"X-Requested-With, Content-Type, Authorization, Date",
)
request.setHeader(
b"Access-Control-Expose-Headers",
b"Synapse-Trace-Id, Server",
)
def set_corp_headers(request: Request) -> None: def set_corp_headers(request: Request) -> None:

View File

@ -48,6 +48,20 @@ logger = logging.getLogger(__name__)
_next_request_seq = 0 _next_request_seq = 0
DEFAULT_CORS_HEADERS: dict[bytes, bytes] = {
b"Access-Control-Allow-Origin": b"*",
b"Access-Control-Allow-Methods": b"GET, HEAD, POST, PUT, DELETE, OPTIONS",
b"Access-Control-Allow-Headers": b"X-Requested-With, Content-Type, Authorization, Date",
b"Access-Control-Expose-Headers": b"Synapse-Trace-Id, Server",
}
EXPERIMENTAL_MSC3886_CORS_HEADERS: dict[bytes, bytes] = {
b"Access-Control-Allow-Origin": b"*",
b"Access-Control-Allow-Methods": b"GET, HEAD, POST, PUT, DELETE, OPTIONS",
b"Access-Control-Allow-Headers": b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match",
b"Access-Control-Expose-Headers": b"ETag, Location, X-Max-Bytes",
}
class SynapseRequest(Request): class SynapseRequest(Request):
"""Class which encapsulates an HTTP request to synapse. """Class which encapsulates an HTTP request to synapse.
@ -87,8 +101,7 @@ class SynapseRequest(Request):
self.reactor = site.reactor self.reactor = site.reactor
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.start_time = 0.0 self.start_time = 0.0
self.experimental_cors_msc3886 = site.experimental_cors_msc3886 self.cors_response_headers = site.cors_response_headers
# The requester, if authenticated. For federation requests this is the # The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object. # server name, for client requests this is the Requester object.
self._requester: Optional[Union[Requester, str]] = None self._requester: Optional[Union[Requester, str]] = None
@ -658,8 +671,14 @@ class SynapseSite(ProxySite):
request_id_header = config.http_options.request_id_header request_id_header = config.http_options.request_id_header
self.experimental_cors_msc3886: bool = ( # Use custom CORS headers if given
config.http_options.experimental_cors_msc3886 self.cors_response_headers = config.http_options.cors_response_headers
# Otherwise, use the default CORS headers
if self.cors_response_headers is None:
self.cors_response_headers = (
EXPERIMENTAL_MSC3886_CORS_HEADERS
if config.http_options.experimental_cors_msc3886
else DEFAULT_CORS_HEADERS
) )
def request_factory(channel: HTTPChannel, queued: bool) -> Request: def request_factory(channel: HTTPChannel, queued: bool) -> Request:

View File

@ -326,6 +326,7 @@ class FakeSite:
self._resource = resource self._resource = resource
self.reactor = reactor self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886 self.experimental_cors_msc3886 = experimental_cors_msc3886
self.cors_response_headers = {}
def getResourceFor(self, request: Request) -> IResource: def getResourceFor(self, request: Request) -> IResource:
return self._resource return self._resource

View File

@ -228,10 +228,16 @@ class OptionsResourceTests(unittest.TestCase):
self.resource.putChild(b"res", DummyResource()) self.resource.putChild(b"res", DummyResource())
def _make_request( def _make_request(
self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False self,
method: bytes,
path: bytes,
experimental_cors_msc3886: bool = False,
cors_response_headers: dict[str, str] = None,
) -> FakeChannel: ) -> FakeChannel:
"""Create a request from the method/path and return a channel with the response.""" """Create a request from the method/path and return a channel with the response."""
# Create a site and query for the resource. # Create a site and query for the resource.
if cors_response_headers is None:
cors_response_headers = {}
site = SynapseSite( site = SynapseSite(
"test", "test",
"site_tag", "site_tag",
@ -241,6 +247,7 @@ class OptionsResourceTests(unittest.TestCase):
"type": "http", "type": "http",
"port": 0, "port": 0,
"experimental_cors_msc3886": experimental_cors_msc3886, "experimental_cors_msc3886": experimental_cors_msc3886,
"cors_response_headers": cors_response_headers,
}, },
), ),
self.resource, self.resource,
@ -340,6 +347,55 @@ class OptionsResourceTests(unittest.TestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"/res/") self.assertEqual(channel.result["body"], b"/res/")
def test_custom_cors(self) -> None:
"""An OPTIONS request to a known URL should return customised CORS headers."""
channel = self._make_request(
b"OPTIONS",
b"/res/",
cors_response_headers={
"Access-Control-Allow-Origin": "https://example.com"
},
)
self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
self.assertEqual(
channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
[b"https://example.com"],
"has correct CORS Origin header",
)
self.assertEqual(
channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
None,
"has correct CORS Expose Headers header",
)
def test_custom_cors_msc3886(self) -> None:
"""An OPTIONS request to a known URL should return customised CORS headers even with MSC3886 enabled."""
channel = self._make_request(
b"OPTIONS",
b"/res/",
experimental_cors_msc3886=True,
cors_response_headers={
"Access-Control-Allow-Origin": "https://example.com"
},
)
self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
self.assertEqual(
channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
[b"https://example.com"],
"has correct CORS Origin header",
)
self.assertEqual(
channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
None,
"has correct CORS Expose Headers header",
)
class WrapHtmlRequestHandlerTests(unittest.TestCase): class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeHtmlResource): class TestResource(DirectServeHtmlResource):