Actually enforce guest + return www-authenticate header
This commit is contained in:
parent
28a9663bdf
commit
5fe96082d0
|
@ -25,7 +25,12 @@ from twisted.web.client import readBody
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
from synapse.api.auth.base import BaseAuth
|
from synapse.api.auth.base import BaseAuth
|
||||||
from synapse.api.errors import AuthError, InvalidClientTokenError, StoreError
|
from synapse.api.errors import (
|
||||||
|
AuthError,
|
||||||
|
InvalidClientTokenError,
|
||||||
|
OAuthInsufficientScopeError,
|
||||||
|
StoreError,
|
||||||
|
)
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import Requester, UserID, create_requester
|
from synapse.types import Requester, UserID, create_requester
|
||||||
|
@ -152,7 +157,16 @@ class OAuthDelegatedAuth(BaseAuth):
|
||||||
allow_expired: bool = False,
|
allow_expired: bool = False,
|
||||||
) -> Requester:
|
) -> Requester:
|
||||||
access_token = self.get_access_token_from_request(request)
|
access_token = self.get_access_token_from_request(request)
|
||||||
return await self.get_user_by_access_token(access_token, allow_expired)
|
|
||||||
|
# TODO: we probably want to assert the allow_guest inside this call so that we don't provision the user if they don't have enough permission:
|
||||||
|
requester = await self.get_user_by_access_token(access_token, allow_expired)
|
||||||
|
|
||||||
|
if not allow_guest and requester.is_guest:
|
||||||
|
raise OAuthInsufficientScopeError(
|
||||||
|
["urn:matrix:org.matrix.msc2967.client:api:*"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return requester
|
||||||
|
|
||||||
async def get_user_by_access_token(
|
async def get_user_by_access_token(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -119,14 +119,20 @@ class Codes(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
"""An exception with integer code and message string attributes.
|
"""An exception with integer code, a message string attributes and optional headers.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
code: HTTP error code
|
code: HTTP error code
|
||||||
msg: string describing the error
|
msg: string describing the error
|
||||||
|
headers: optional response headers to send
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, code: Union[int, HTTPStatus], msg: str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
code: Union[int, HTTPStatus],
|
||||||
|
msg: str,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
):
|
||||||
super().__init__("%d: %s" % (code, msg))
|
super().__init__("%d: %s" % (code, msg))
|
||||||
|
|
||||||
# Some calls to this method pass instances of http.HTTPStatus for `code`.
|
# Some calls to this method pass instances of http.HTTPStatus for `code`.
|
||||||
|
@ -137,6 +143,7 @@ class CodeMessageException(RuntimeError):
|
||||||
# To eliminate this behaviour, we convert them to their integer equivalents here.
|
# To eliminate this behaviour, we convert them to their integer equivalents here.
|
||||||
self.code = int(code)
|
self.code = int(code)
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
self.headers = headers
|
||||||
|
|
||||||
|
|
||||||
class RedirectException(CodeMessageException):
|
class RedirectException(CodeMessageException):
|
||||||
|
@ -182,6 +189,7 @@ class SynapseError(CodeMessageException):
|
||||||
msg: str,
|
msg: str,
|
||||||
errcode: str = Codes.UNKNOWN,
|
errcode: str = Codes.UNKNOWN,
|
||||||
additional_fields: Optional[Dict] = None,
|
additional_fields: Optional[Dict] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
):
|
):
|
||||||
"""Constructs a synapse error.
|
"""Constructs a synapse error.
|
||||||
|
|
||||||
|
@ -190,7 +198,7 @@ class SynapseError(CodeMessageException):
|
||||||
msg: The human-readable error message.
|
msg: The human-readable error message.
|
||||||
errcode: The matrix error code e.g 'M_FORBIDDEN'
|
errcode: The matrix error code e.g 'M_FORBIDDEN'
|
||||||
"""
|
"""
|
||||||
super().__init__(code, msg)
|
super().__init__(code, msg, headers)
|
||||||
self.errcode = errcode
|
self.errcode = errcode
|
||||||
if additional_fields is None:
|
if additional_fields is None:
|
||||||
self._additional_fields: Dict = {}
|
self._additional_fields: Dict = {}
|
||||||
|
@ -335,6 +343,20 @@ class AuthError(SynapseError):
|
||||||
super().__init__(code, msg, errcode, additional_fields)
|
super().__init__(code, msg, errcode, additional_fields)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthInsufficientScopeError(SynapseError):
|
||||||
|
"""An error raised when the caller does not have sufficient scope to perform the requested action"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
required_scopes: List[str],
|
||||||
|
):
|
||||||
|
headers = {
|
||||||
|
"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="%s"'
|
||||||
|
% (" ".join(required_scopes))
|
||||||
|
}
|
||||||
|
super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None, headers)
|
||||||
|
|
||||||
|
|
||||||
class UnstableSpecAuthError(AuthError):
|
class UnstableSpecAuthError(AuthError):
|
||||||
"""An error raised when a new error code is being proposed to replace a previous one.
|
"""An error raised when a new error code is being proposed to replace a previous one.
|
||||||
This error will return a "org.matrix.unstable.errcode" property with the new error code,
|
This error will return a "org.matrix.unstable.errcode" property with the new error code,
|
||||||
|
|
|
@ -111,6 +111,9 @@ def return_json_error(
|
||||||
exc: SynapseError = f.value # type: ignore
|
exc: SynapseError = f.value # type: ignore
|
||||||
error_code = exc.code
|
error_code = exc.code
|
||||||
error_dict = exc.error_dict(config)
|
error_dict = exc.error_dict(config)
|
||||||
|
if exc.headers is not None:
|
||||||
|
for header, value in exc.headers.items():
|
||||||
|
request.setHeader(header, value)
|
||||||
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
|
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
|
||||||
elif f.check(CancelledError):
|
elif f.check(CancelledError):
|
||||||
error_code = HTTP_STATUS_REQUEST_CANCELLED
|
error_code = HTTP_STATUS_REQUEST_CANCELLED
|
||||||
|
@ -172,6 +175,9 @@ def return_html_error(
|
||||||
cme: CodeMessageException = f.value # type: ignore
|
cme: CodeMessageException = f.value # type: ignore
|
||||||
code = cme.code
|
code = cme.code
|
||||||
msg = cme.msg
|
msg = cme.msg
|
||||||
|
if cme.headers is not None:
|
||||||
|
for header, value in cme.headers.items():
|
||||||
|
request.setHeader(header, value)
|
||||||
|
|
||||||
if isinstance(cme, RedirectException):
|
if isinstance(cme, RedirectException):
|
||||||
logger.info("%s redirect to %s", request, cme.location)
|
logger.info("%s redirect to %s", request, cme.location)
|
||||||
|
|
|
@ -17,7 +17,8 @@ from urllib.parse import parse_qs
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import InvalidClientTokenError
|
from synapse.api.errors import InvalidClientTokenError, OAuthInsufficientScopeError
|
||||||
|
from synapse.rest.client import devices
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -82,6 +83,10 @@ async def get_json(url: str) -> JsonDict:
|
||||||
|
|
||||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||||
class MSC3861OAuthDelegation(HomeserverTestCase):
|
class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
devices.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
def default_config(self) -> Dict[str, Any]:
|
def default_config(self) -> Dict[str, Any]:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["public_baseurl"] = BASE_URL
|
config["public_baseurl"] = BASE_URL
|
||||||
|
@ -314,7 +319,37 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(requester.device_id, DEVICE)
|
self.assertEqual(requester.device_id, DEVICE)
|
||||||
|
|
||||||
def test_active_guest_with_device(self) -> None:
|
def test_active_guest_not_allowed(self) -> None:
|
||||||
|
"""The handler should return an insufficient scope error."""
|
||||||
|
|
||||||
|
self.http_client.request = simple_async_mock(
|
||||||
|
return_value=FakeResponse.json(
|
||||||
|
code=200,
|
||||||
|
payload={
|
||||||
|
"active": True,
|
||||||
|
"sub": SUBJECT,
|
||||||
|
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||||
|
"username": USERNAME,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
request = Mock(args={})
|
||||||
|
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||||
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
error = self.get_failure(
|
||||||
|
self.auth.get_user_by_req(request), OAuthInsufficientScopeError
|
||||||
|
)
|
||||||
|
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||||
|
self.http_client.request.assert_called_once_with(
|
||||||
|
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||||
|
)
|
||||||
|
self._assertParams()
|
||||||
|
self.assertEqual(
|
||||||
|
getattr(error.value, "headers", {})["WWW-Authenticate"],
|
||||||
|
'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"',
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_active_guest_allowed(self) -> None:
|
||||||
"""The handler should return a requester with guest user rights and a device ID."""
|
"""The handler should return a requester with guest user rights and a device ID."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = simple_async_mock(
|
||||||
|
@ -331,7 +366,9 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(
|
||||||
|
self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
)
|
||||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||||
self.http_client.request.assert_called_once_with(
|
self.http_client.request.assert_called_once_with(
|
||||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||||
|
|
Loading…
Reference in New Issue