Add a cache around server ACL checking (#16360)
* Pre-compiles the server ACLs onto an object per room and invalidates them when new events come in. * Converts the server ACL checking into Rust.
This commit is contained in:
parent
17800a0e97
commit
f84da3c32e
|
@ -0,0 +1 @@
|
|||
Cache server ACL checking.
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! An implementation of Matrix server ACL rules.
|
||||
|
||||
use std::net::Ipv4Addr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::Error;
|
||||
use pyo3::prelude::*;
|
||||
use regex::Regex;
|
||||
|
||||
use crate::push::utils::{glob_to_regex, GlobMatchType};
|
||||
|
||||
/// Called when registering modules with python.
|
||||
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let child_module = PyModule::new(py, "acl")?;
|
||||
child_module.add_class::<ServerAclEvaluator>()?;
|
||||
|
||||
m.add_submodule(child_module)?;
|
||||
|
||||
// We need to manually add the module to sys.modules to make `from
|
||||
// synapse.synapse_rust import acl` work.
|
||||
py.import("sys")?
|
||||
.getattr("modules")?
|
||||
.set_item("synapse.synapse_rust.acl", child_module)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[pyclass(frozen)]
|
||||
pub struct ServerAclEvaluator {
|
||||
allow_ip_literals: bool,
|
||||
allow: Vec<Regex>,
|
||||
deny: Vec<Regex>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl ServerAclEvaluator {
|
||||
#[new]
|
||||
pub fn py_new(
|
||||
allow_ip_literals: bool,
|
||||
allow: Vec<&str>,
|
||||
deny: Vec<&str>,
|
||||
) -> Result<Self, Error> {
|
||||
let allow = allow
|
||||
.iter()
|
||||
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
|
||||
.collect::<Result<_, _>>()?;
|
||||
let deny = deny
|
||||
.iter()
|
||||
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
Ok(ServerAclEvaluator {
|
||||
allow_ip_literals,
|
||||
allow,
|
||||
deny,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn server_matches_acl_event(&self, server_name: &str) -> bool {
|
||||
// first of all, check if literal IPs are blocked, and if so, whether the
|
||||
// server name is a literal IP
|
||||
if !self.allow_ip_literals {
|
||||
// check for ipv6 literals. These start with '['.
|
||||
if server_name.starts_with('[') {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check for ipv4 literals. We can just lift the routine from std::net.
|
||||
if Ipv4Addr::from_str(server_name).is_ok() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// next, check the deny list
|
||||
if self.deny.iter().any(|e| e.is_match(server_name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// then the allow list.
|
||||
if self.allow.iter().any(|e| e.is_match(server_name)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// everything else should be rejected.
|
||||
false
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ use lazy_static::lazy_static;
|
|||
use pyo3::prelude::*;
|
||||
use pyo3_log::ResetHandle;
|
||||
|
||||
pub mod acl;
|
||||
pub mod push;
|
||||
|
||||
lazy_static! {
|
||||
|
@ -38,6 +39,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?;
|
||||
|
||||
acl::register_module(py, m)?;
|
||||
push::register_module(py, m)?;
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
class ServerAclEvaluator:
|
||||
def __init__(
|
||||
self, allow_ip_literals: bool, allow: List[str], deny: List[str]
|
||||
) -> None: ...
|
||||
def server_matches_acl_event(self, server_name: str) -> bool: ...
|
|
@ -39,9 +39,9 @@ from synapse.events.utils import (
|
|||
CANONICALJSON_MIN_INT,
|
||||
validate_canonicaljson,
|
||||
)
|
||||
from synapse.federation.federation_server import server_matches_acl_event
|
||||
from synapse.http.servlet import validate_json_object
|
||||
from synapse.rest.models import RequestBodyModel
|
||||
from synapse.storage.controllers.state import server_acl_evaluator_from_event
|
||||
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
|
||||
|
||||
|
||||
|
@ -106,7 +106,10 @@ class EventValidator:
|
|||
self._validate_retention(event)
|
||||
|
||||
elif event.type == EventTypes.ServerACL:
|
||||
if not server_matches_acl_event(config.server.server_name, event):
|
||||
server_acl_evaluator = server_acl_evaluator_from_event(event)
|
||||
if not server_acl_evaluator.server_matches_acl_event(
|
||||
config.server.server_name
|
||||
):
|
||||
raise SynapseError(
|
||||
400, "Can't create an ACL event that denies the local server"
|
||||
)
|
||||
|
|
|
@ -29,10 +29,8 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from matrix_common.regex import glob_to_regex
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.constants import (
|
||||
|
@ -1324,77 +1322,15 @@ class FederationServer(FederationBase):
|
|||
Raises:
|
||||
AuthError if the server does not match the ACL
|
||||
"""
|
||||
acl_event = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.ServerACL, ""
|
||||
server_acl_evaluator = (
|
||||
await self._storage_controllers.state.get_server_acl_for_room(room_id)
|
||||
)
|
||||
if not acl_event or server_matches_acl_event(server_name, acl_event):
|
||||
return
|
||||
|
||||
if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
|
||||
server_name
|
||||
):
|
||||
raise AuthError(code=403, msg="Server is banned from room")
|
||||
|
||||
|
||||
def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
|
||||
"""Check if the given server is allowed by the ACL event
|
||||
|
||||
Args:
|
||||
server_name: name of server, without any port part
|
||||
acl_event: m.room.server_acl event
|
||||
|
||||
Returns:
|
||||
True if this server is allowed by the ACLs
|
||||
"""
|
||||
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
|
||||
|
||||
# first of all, check if literal IPs are blocked, and if so, whether the
|
||||
# server name is a literal IP
|
||||
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
|
||||
if not isinstance(allow_ip_literals, bool):
|
||||
logger.warning("Ignoring non-bool allow_ip_literals flag")
|
||||
allow_ip_literals = True
|
||||
if not allow_ip_literals:
|
||||
# check for ipv6 literals. These start with '['.
|
||||
if server_name[0] == "[":
|
||||
return False
|
||||
|
||||
# check for ipv4 literals. We can just lift the routine from twisted.
|
||||
if isIPAddress(server_name):
|
||||
return False
|
||||
|
||||
# next, check the deny list
|
||||
deny = acl_event.content.get("deny", [])
|
||||
if not isinstance(deny, (list, tuple)):
|
||||
logger.warning("Ignoring non-list deny ACL %s", deny)
|
||||
deny = []
|
||||
for e in deny:
|
||||
if _acl_entry_matches(server_name, e):
|
||||
# logger.info("%s matched deny rule %s", server_name, e)
|
||||
return False
|
||||
|
||||
# then the allow list.
|
||||
allow = acl_event.content.get("allow", [])
|
||||
if not isinstance(allow, (list, tuple)):
|
||||
logger.warning("Ignoring non-list allow ACL %s", allow)
|
||||
allow = []
|
||||
for e in allow:
|
||||
if _acl_entry_matches(server_name, e):
|
||||
# logger.info("%s matched allow rule %s", server_name, e)
|
||||
return True
|
||||
|
||||
# everything else should be rejected.
|
||||
# logger.info("%s fell through", server_name)
|
||||
return False
|
||||
|
||||
|
||||
def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
|
||||
if not isinstance(acl_entry, str):
|
||||
logger.warning(
|
||||
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
|
||||
)
|
||||
return False
|
||||
regex = glob_to_regex(acl_entry)
|
||||
return bool(regex.match(server_name))
|
||||
|
||||
|
||||
class FederationHandlerRegistry:
|
||||
"""Allows classes to register themselves as handlers for a given EDU or
|
||||
query type for incoming federation traffic.
|
||||
|
|
|
@ -2342,6 +2342,12 @@ class FederationEventHandler:
|
|||
# TODO retrieve the previous state, and exclude join -> join transitions
|
||||
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
|
||||
|
||||
# If this is a server ACL event, clear the cache in the storage controller.
|
||||
if event.type == EventTypes.ServerACL:
|
||||
self._state_storage_controller.get_server_acl_for_room.invalidate(
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
def _sanity_check_event(self, ev: EventBase) -> None:
|
||||
"""
|
||||
Do some early sanity checks of a received event
|
||||
|
|
|
@ -1730,6 +1730,11 @@ class EventCreationHandler:
|
|||
event.event_id, event.room_id
|
||||
)
|
||||
|
||||
if event.type == EventTypes.ServerACL:
|
||||
self._storage_controllers.state.get_server_acl_for_room.invalidate(
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
await self._maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
|
|
|
@ -205,6 +205,12 @@ class ReplicationDataHandler:
|
|||
self.notifier.notify_user_joined_room(
|
||||
row.data.event_id, row.data.room_id
|
||||
)
|
||||
|
||||
# If this is a server ACL event, clear the cache in the storage controller.
|
||||
if row.data.type == EventTypes.ServerACL:
|
||||
self._state_storage_controller.get_server_acl_for_room.invalidate(
|
||||
(row.data.room_id,)
|
||||
)
|
||||
elif stream_name == UnPartialStatedRoomStream.NAME:
|
||||
for row in rows:
|
||||
assert isinstance(row, UnPartialStatedRoomStreamRow)
|
||||
|
|
|
@ -37,6 +37,7 @@ from synapse.storage.util.partial_state_events_tracker import (
|
|||
PartialCurrentStateTracker,
|
||||
PartialStateEventsTracker,
|
||||
)
|
||||
from synapse.synapse_rust.acl import ServerAclEvaluator
|
||||
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
@ -501,6 +502,31 @@ class StateStorageController:
|
|||
|
||||
return event.content.get("alias")
|
||||
|
||||
@cached()
|
||||
async def get_server_acl_for_room(
|
||||
self, room_id: str
|
||||
) -> Optional[ServerAclEvaluator]:
|
||||
"""Get the server ACL evaluator for room, if any
|
||||
|
||||
This does up-front parsing of the content to ignore bad data and pre-compile
|
||||
regular expressions.
|
||||
|
||||
Args:
|
||||
room_id: The room ID
|
||||
|
||||
Returns:
|
||||
The server ACL evaluator, if any
|
||||
"""
|
||||
|
||||
acl_event = await self.get_current_state_event(
|
||||
room_id, EventTypes.ServerACL, ""
|
||||
)
|
||||
|
||||
if not acl_event:
|
||||
return None
|
||||
|
||||
return server_acl_evaluator_from_event(acl_event)
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
async def get_current_state_deltas(
|
||||
|
@ -760,3 +786,36 @@ class StateStorageController:
|
|||
cache.state_group = object()
|
||||
|
||||
return frozenset(cache.hosts_to_joined_users)
|
||||
|
||||
|
||||
def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
|
||||
"""
|
||||
Create a ServerAclEvaluator from a m.room.server_acl event's content.
|
||||
|
||||
This does up-front parsing of the content to ignore bad data. It then creates
|
||||
the ServerAclEvaluator which will pre-compile regular expressions from the globs.
|
||||
"""
|
||||
|
||||
# first of all, parse if literal IPs are blocked.
|
||||
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
|
||||
if not isinstance(allow_ip_literals, bool):
|
||||
logger.warning("Ignoring non-bool allow_ip_literals flag")
|
||||
allow_ip_literals = True
|
||||
|
||||
# next, parse the deny list by ignoring any non-strings.
|
||||
deny = acl_event.content.get("deny", [])
|
||||
if not isinstance(deny, (list, tuple)):
|
||||
logger.warning("Ignoring non-list deny ACL %s", deny)
|
||||
deny = []
|
||||
else:
|
||||
deny = [s for s in deny if isinstance(s, str)]
|
||||
|
||||
# then the allow list.
|
||||
allow = acl_event.content.get("allow", [])
|
||||
if not isinstance(allow, (list, tuple)):
|
||||
logger.warning("Ignoring non-list allow ACL %s", allow)
|
||||
allow = []
|
||||
else:
|
||||
allow = [s for s in allow if isinstance(s, str)]
|
||||
|
||||
return ServerAclEvaluator(allow_ip_literals, allow, deny)
|
||||
|
|
|
@ -22,10 +22,10 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.federation.federation_server import server_matches_acl_event
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.controllers.state import server_acl_evaluator_from_event
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
|
@ -67,37 +67,46 @@ class ServerACLsTestCase(unittest.TestCase):
|
|||
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
|
||||
logging.info("ACL event: %s", e.content)
|
||||
|
||||
self.assertFalse(server_matches_acl_event("evil.com", e))
|
||||
self.assertFalse(server_matches_acl_event("EVIL.COM", e))
|
||||
server_acl_evalutor = server_acl_evaluator_from_event(e)
|
||||
|
||||
self.assertTrue(server_matches_acl_event("evil.com.au", e))
|
||||
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
|
||||
self.assertFalse(server_acl_evalutor.server_matches_acl_event("evil.com"))
|
||||
self.assertFalse(server_acl_evalutor.server_matches_acl_event("EVIL.COM"))
|
||||
|
||||
self.assertTrue(server_acl_evalutor.server_matches_acl_event("evil.com.au"))
|
||||
self.assertTrue(
|
||||
server_acl_evalutor.server_matches_acl_event("honestly.not.evil.com")
|
||||
)
|
||||
|
||||
def test_block_ip_literals(self) -> None:
|
||||
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
|
||||
logging.info("ACL event: %s", e.content)
|
||||
|
||||
self.assertFalse(server_matches_acl_event("1.2.3.4", e))
|
||||
self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
|
||||
self.assertFalse(server_matches_acl_event("[1:2::]", e))
|
||||
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
|
||||
server_acl_evalutor = server_acl_evaluator_from_event(e)
|
||||
|
||||
self.assertFalse(server_acl_evalutor.server_matches_acl_event("1.2.3.4"))
|
||||
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1a.2.3.4"))
|
||||
self.assertFalse(server_acl_evalutor.server_matches_acl_event("[1:2::]"))
|
||||
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1:2:3:4"))
|
||||
|
||||
def test_wildcard_matching(self) -> None:
|
||||
e = _create_acl_event({"allow": ["good*.com"]})
|
||||
|
||||
server_acl_evalutor = server_acl_evaluator_from_event(e)
|
||||
|
||||
self.assertTrue(
|
||||
server_matches_acl_event("good.com", e),
|
||||
server_acl_evalutor.server_matches_acl_event("good.com"),
|
||||
"* matches 0 characters",
|
||||
)
|
||||
self.assertTrue(
|
||||
server_matches_acl_event("GOOD.COM", e),
|
||||
server_acl_evalutor.server_matches_acl_event("GOOD.COM"),
|
||||
"pattern is case-insensitive",
|
||||
)
|
||||
self.assertTrue(
|
||||
server_matches_acl_event("good.aa.com", e),
|
||||
server_acl_evalutor.server_matches_acl_event("good.aa.com"),
|
||||
"* matches several characters, including '.'",
|
||||
)
|
||||
self.assertFalse(
|
||||
server_matches_acl_event("ishgood.com", e),
|
||||
server_acl_evalutor.server_matches_acl_event("ishgood.com"),
|
||||
"pattern does not allow prefixes",
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue