Use literals in place of `HTTPStatus` constants in tests (#13463)

This commit is contained in:
Dirk Klimpel 2022-08-05 16:59:09 +02:00 committed by GitHub
parent 3d2cabf966
commit e2ed1b7155
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 172 additions and 191 deletions

1
changelog.d/13463.misc Normal file
View File

@ -0,0 +1 @@
Use literals in place of `HTTPStatus` constants in tests.

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -51,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity) self.assertTrue(complexity > 0, complexity)
@ -63,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23) self.assertEqual(complexity, 1.23)

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
@ -256,7 +255,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier, RoomVersions.V7.identifier,
), ),
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as # Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that # part of the stripped room state, as the knocking homeserver already has that
@ -294,7 +293,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id), % (room_id, signed_knock_event.event_id),
signed_knock_event_json, signed_knock_event_json,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(200, channel.code, channel.result)
# Check that we got the stripped room state in return # Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"] room_state_events = channel.json_body["knock_state_events"]

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from typing import Any, Dict from typing import Any, Dict
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -58,7 +57,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
access_token=self.token, access_token=self.token,
) )
self.assertEqual(req.code, HTTPStatus.OK, req) self.assertEqual(req.code, 200, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None: def test_global_account_data_deleted_upon_deactivation(self) -> None:
""" """

View File

@ -314,4 +314,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", path, content={}, access_token=self.access_token "POST", path, content={}, access_token=self.access_token
) )
self.assertEqual(int(channel.result["code"]), 403) self.assertEqual(channel.code, 403)

View File

@ -1,4 +1,3 @@
from http import HTTPStatus
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -260,7 +259,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
f"/_matrix/client/v3/rooms/{self.room_id}/join", f"/_matrix/client/v3/rooms/{self.room_id}/join",
access_token=self.bob_token, access_token=self.bob_token,
) )
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
# wait for join to arrive over replication # wait for join to arrive over replication
self.replicate() self.replicate()

View File

@ -15,7 +15,6 @@
import inspect import inspect
import itertools import itertools
import logging import logging
from http import HTTPStatus
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -78,7 +77,7 @@ def test_disconnect(
if expect_cancellation: if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED expected_code = HTTP_STATUS_REQUEST_CANCELLED
else: else:
expected_code = HTTPStatus.OK expected_code = 200
request = channel.request request = channel.request
if channel.is_finished(): if channel.is_finished():

View File

@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"}) self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success( filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0) self.store.get_user_filter(user_localpart="apple", filter_id=0)
@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self) -> None: def test_add_filter_non_local_user(self) -> None:
@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
) )
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self) -> None: def test_get_filter(self) -> None:
@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self) -> None: def test_get_filter_non_existant(self) -> None:
@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
) )
self.assertEqual(channel.result["code"], b"404") self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode # Currently invalid params do not have an appropriate errcode
@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
) )
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.code, 400)
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
def test_get_filter_no_id(self) -> None: def test_get_filter_no_id(self) -> None:
@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
) )
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.code, 400)

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import time import time
import urllib.parse import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from urllib.parse import urlencode from urllib.parse import urlencode
@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config( @override_config(
{ {
@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config( @override_config(
{ {
@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min. # than 1min.
@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None: def test_soft_logout(self) -> None:
@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token # we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL) channel = self.make_request(b"GET", TEST_URL)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal # log in as normal
@ -261,20 +260,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"] access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"] device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -288,7 +287,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout # more requests with the expired token should still return a soft-logout
self.reactor.advance(3600) self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -296,7 +295,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id) self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False) self.assertEqual(channel.json_body["soft_logout"], False)
@ -307,7 +306,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token b"DELETE", "devices/" + device_id, access_token=access_token
) )
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
# check it's a UI-Auth fail # check it's a UI-Auth fail
self.assertEqual( self.assertEqual(
set(channel.json_body.keys()), set(channel.json_body.keys()),
@ -330,7 +329,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
content={"auth": auth}, content={"auth": auth},
) )
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@ -341,20 +340,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard logout this session # Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token) channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@ -367,20 +366,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions # Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token) channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None: def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese") self.register_user("mickey", "cheese")
@ -466,7 +465,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None: def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows""" """GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [ expected_flow_types = [
"m.login.cas", "m.login.cas",
@ -494,14 +493,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None) channel = self._make_sso_redirect_request(None)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
uri = location_headers[0] uri = location_headers[0]
# hitting that picker should give us some HTML # hitting that picker should give us some HTML
channel = self.make_request("GET", uri) channel = self.make_request("GET", uri)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class # parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8") html = channel.result["body"].decode("utf-8")
@ -530,7 +529,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas", + "&idp=cas",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
cas_uri = location_headers[0] cas_uri = location_headers[0]
@ -555,7 +554,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml", + "&idp=saml",
) )
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
saml_uri = location_headers[0] saml_uri = location_headers[0]
@ -579,7 +578,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc", + "&idp=oidc",
) )
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
@ -606,7 +605,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page # that should serve a confirmation page
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type") content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html")) self.assertTrue(content_type_headers[-1].startswith("text/html"))
@ -634,7 +633,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test") self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None: def test_multi_sso_redirect_to_unknown(self) -> None:
@ -643,18 +642,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
) )
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx") channel = self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc") channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
@ -765,7 +764,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url) channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML. # Test that the response is HTML.
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = "" content_type_header_value = ""
for header in channel.result.get("headers", []): for header in channel.result.get("headers", []):
if header[0] == b"Content-Type": if header[0] == b"Content-Type":
@ -878,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid_registered(self) -> None: def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey") self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None: def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"}) channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None: def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret") channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -897,7 +896,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self) -> None: def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000}) channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -907,7 +906,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self) -> None: def test_login_jwt_not_before(self) -> None:
now = int(time.time()) now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -916,7 +915,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self) -> None: def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"}) channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT") self.assertEqual(channel.json_body["error"], "Invalid JWT")
@ -925,12 +924,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the issuer claim.""" """Test validating the issuer claim."""
# A valid issuer. # A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer. # An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -939,7 +938,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an issuer. # Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -949,7 +948,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_iss_no_config(self) -> None: def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration.""" """Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@ -957,12 +956,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the audience claim.""" """Test validating the audience claim."""
# A valid audience. # A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience. # An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -971,7 +970,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an audience. # Not providing an audience.
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -981,7 +980,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self) -> None: def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration.""" """Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -991,20 +990,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_default_sub(self) -> None: def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim.""" """Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None: def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim.""" """Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"}) channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test") self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self) -> None: def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"} params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@ -1086,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid(self) -> None: def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"}) channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test") self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None: def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual( self.assertEqual(
channel.json_body["error"], channel.json_body["error"],
@ -1152,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token b"POST", LOGIN_URL, params, access_token=self.service.token
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_user_bot(self) -> None: def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login""" """Test that the appservice bot can use /login"""
@ -1166,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token b"POST", LOGIN_URL, params, access_token=self.service.token
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None: def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token""" """Test that non-as users cannot login with the as token"""
@ -1180,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token b"POST", LOGIN_URL, params, access_token=self.service.token
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None: def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token""" """Test that as users cannot login with wrong as token"""
@ -1194,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token b"POST", LOGIN_URL, params, access_token=self.another_service.token
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None: def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice """Test that users must provide a token when using the appservice
@ -1208,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
@skip_unless(HAS_OIDC, "requires OIDC") @skip_unless(HAS_OIDC, "requires OIDC")
@ -1246,7 +1245,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
) )
# that should redirect to the username picker # that should redirect to the username picker
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
picker_url = location_headers[0] picker_url = location_headers[0]
@ -1290,7 +1289,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))), ("Content-Length", str(len(content))),
], ],
) )
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -1300,7 +1299,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0], path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)], custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
) )
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -1325,5 +1324,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test") self.assertEqual(chan.json_body["user_id"], "@bobby:test")

View File

@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase):
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
channel = self.make_request("POST", path, content={}, access_token=access_token) channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code) self.assertEqual(channel.code, expect_code)
return channel.json_body return channel.json_body
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token) channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
room_sync = channel.json_body["rooms"]["join"][room_id] room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"] return room_sync["timeline"]["events"]

View File

@ -70,7 +70,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname} det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@ -91,7 +91,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.code, 400, msg=channel.result)
def test_POST_appservice_registration_invalid(self) -> None: def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists self.appservice = None # no application service exists
@ -100,20 +100,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
def test_POST_bad_password(self) -> None: def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666} request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password") self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None: def test_POST_bad_username(self) -> None:
request_data = {"username": 777, "password": "monkey"} request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username") self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self) -> None: def test_POST_user_valid(self) -> None:
@ -132,7 +132,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False}) @override_config({"enable_registration": False})
@ -142,7 +142,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -153,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self) -> None: def test_POST_disabled_guest_registration(self) -> None:
@ -161,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Guest access is disabled") self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@ -171,16 +171,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", url, b"{}") channel = self.make_request(b"POST", url, b"{}")
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None: def test_POST_ratelimiting(self) -> None:
@ -194,16 +194,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"registration_requires_token": True}) @override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self) -> None: def test_POST_registration_requires_token(self) -> None:
@ -231,7 +231,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Request without auth to get flows and session # Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one # Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow. # flow would be a subset of another flow.
@ -248,7 +248,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session, "session": session,
} }
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
completed = channel.json_body["completed"] completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@ -263,7 +263,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0 # Check the `completed` counter has been incremented and pending is 0
@ -293,21 +293,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session, "session": session,
} }
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid) # Test with non-string (invalid)
params["auth"]["token"] = 1234 params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid) # Test with unknown token (invalid)
params["auth"]["token"] = "1234" params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -361,7 +361,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session2, "session": session2,
} }
channel = self.make_request(b"POST", self.url, params2) channel = self.make_request(b"POST", self.url, params2)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -381,7 +381,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Check auth still fails when using token with session2 # Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2) channel = self.make_request(b"POST", self.url, params2)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -415,7 +415,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session, "session": session,
} }
channel = self.make_request(b"POST", self.url, params) channel = self.make_request(b"POST", self.url, params)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], []) self.assertEqual(channel.json_body["completed"], [])
@ -570,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_advertised_flows(self) -> None: def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
# with the stock config, we only expect the dummy flow # with the stock config, we only expect the dummy flow
@ -593,7 +593,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
) )
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
self.assertCountEqual( self.assertCountEqual(
@ -625,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
) )
def test_advertised_flows_no_msisdn_email_required(self) -> None: def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}") channel = self.make_request(b"POST", self.url, b"{}")
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"] flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid # with the stock config, we expect all four combinations of 3pid
@ -797,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# endpoint. # endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
@ -823,12 +823,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/account_validity/validity" url = "/_synapse/admin/v1/account_validity/validity"
request_data = {"user_id": user_id} request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated # The specific endpoint doesn't matter, all we need is an authenticated
# endpoint. # endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_manual_expire(self) -> None: def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
@ -844,12 +844,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False, "enable_renewal_emails": False,
} }
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated # The specific endpoint doesn't matter, all we need is an authenticated
# endpoint. # endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
@ -868,18 +868,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False, "enable_renewal_emails": False,
} }
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Try to log the user out # Try to log the user out
channel = self.make_request(b"POST", "/logout", access_token=tok) channel = self.make_request(b"POST", "/logout", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Log the user in again (allowed for expired accounts) # Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions # Try to log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=tok) channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@ -954,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url) channel = self.make_request(b"GET", url)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back. # Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type") content_type = channel.headers.getRawHeaders(b"Content-Type")
@ -972,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again. # Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url) channel = self.make_request(b"GET", url)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back. # Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type") content_type = channel.headers.getRawHeaders(b"Content-Type")
@ -992,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed. # succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds()) self.reactor.advance(datetime.timedelta(days=3).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok) channel = self.make_request(b"GET", "/sync", access_token=tok)
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
def test_renewal_invalid_token(self) -> None: def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as # Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123" url = "/_matrix/client/unstable/account_validity/renew?token=123"
channel = self.make_request(b"GET", url) channel = self.make_request(b"GET", url)
self.assertEqual(channel.result["code"], b"404", channel.result) self.assertEqual(channel.code, 404, msg=channel.result)
# Check that we're getting HTML back. # Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type") content_type = channel.headers.getRawHeaders(b"Content-Type")
@ -1023,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail", "/_matrix/client/unstable/account_validity/send_mail",
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
@ -1096,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail", "/_matrix/client/unstable/account_validity/send_mail",
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
@ -1176,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET", b"GET",
f"{self.url}?token={token}", f"{self.url}?token={token}",
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], True) self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self) -> None: def test_GET_token_invalid(self) -> None:
@ -1185,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET", b"GET",
f"{self.url}?token={token}", f"{self.url}?token={token}",
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], False) self.assertEqual(channel.json_body["valid"], False)
@override_config( @override_config(
@ -1201,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
) )
if i == 5: if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"]) retry_after_ms = int(channel.json_body["retry_after_ms"])
else: else:
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0) self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@ -1212,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET", b"GET",
f"{self.url}?token={token}", f"{self.url}?token={token}",
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, msg=channel.result)

View File

@ -77,6 +77,4 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok "POST", self.report_path, data, access_token=self.other_user_tok
) )
self.assertEqual( self.assertEqual(response_status, channel.code, msg=channel.result["body"])
response_status, int(channel.result["code"]), msg=channel.result["body"]
)

View File

@ -155,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{}, {},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
callback.assert_called_once() callback.assert_called_once()
@ -173,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{}, {},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.code, 403, channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
""" """
@ -211,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
access_token=self.tok, access_token=self.tok,
) )
# Check the error code # Check the error code
self.assertEqual(channel.result["code"], b"429", channel.result) self.assertEqual(channel.code, 429, channel.result)
# Check the JSON body has had the `nasty` key injected # Check the JSON body has had the `nasty` key injected
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
@ -260,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"x": "x"}, {"x": "x"},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"] event_id = channel.json_body["event_id"]
# ... and check that it got modified # ... and check that it got modified
@ -269,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y") self.assertEqual(ev["content"]["x"], "y")
@ -298,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}, },
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
orig_event_id = channel.json_body["event_id"] orig_event_id = channel.json_body["event_id"]
channel = self.make_request( channel = self.make_request(
@ -315,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}, },
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
edited_event_id = channel.json_body["event_id"] edited_event_id = channel.json_body["event_id"]
# ... and check that they both got modified # ... and check that they both got modified
@ -324,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
@ -333,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY") self.assertEqual(ev["content"]["body"], "EDITED BODY")
@ -379,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}, },
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"] event_id = channel.json_body["event_id"]
@ -388,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertIn("foo", channel.json_body["content"].keys()) self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar") self.assertEqual(channel.json_body["content"]["foo"], "bar")

View File

@ -140,7 +140,7 @@ class RestHelper:
custom_headers=custom_headers, custom_headers=custom_headers,
) )
assert channel.result["code"] == b"%d" % expect_code, channel.result assert channel.code == expect_code, channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
if expect_code == HTTPStatus.OK: if expect_code == HTTPStatus.OK:
@ -213,11 +213,9 @@ class RestHelper:
data, data,
) )
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -312,11 +310,9 @@ class RestHelper:
data, data,
) )
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -396,11 +392,9 @@ class RestHelper:
custom_headers=custom_headers, custom_headers=custom_headers,
) )
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -449,11 +443,9 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content) channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
assert ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
int(channel.result["code"]) == expect_code
), "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )
@ -545,7 +537,7 @@ class RestHelper:
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
int(channel.result["code"]), channel.code,
channel.result["body"], channel.result["body"],
) )

View File

@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from synapse.rest.health import HealthResource from synapse.rest.health import HealthResource
from tests import unittest from tests import unittest
@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase):
def test_health(self) -> None: def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False) channel = self.make_request("GET", "/health", shorthand=False)
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK") self.assertEqual(channel.result["body"], b"OK")

View File

@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource from synapse.rest.well_known import well_known_resource
@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{ {
@ -57,7 +55,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) self.assertEqual(channel.code, 404)
@unittest.override_config( @unittest.override_config(
{ {
@ -71,7 +69,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False "GET", "/.well-known/matrix/client", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{ {
@ -87,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False "GET", "/.well-known/matrix/server", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
channel.json_body, channel.json_body,
{"m.server": "test:443"}, {"m.server": "test:443"},
@ -97,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False "GET", "/.well-known/matrix/server", shorthand=False
) )
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) self.assertEqual(channel.code, 404)

View File

@ -104,7 +104,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
) )
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.code, 500)
def test_callback_indirect_exception(self) -> None: def test_callback_indirect_exception(self) -> None:
""" """
@ -130,7 +130,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
) )
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.code, 500)
def test_callback_synapseerror(self) -> None: def test_callback_synapseerror(self) -> None:
""" """
@ -150,7 +150,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
) )
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
) )
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
@ -203,7 +203,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo" self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)
@ -242,7 +242,7 @@ class OptionsResourceTests(unittest.TestCase):
def test_unknown_options_request(self) -> None: def test_unknown_options_request(self) -> None:
"""An OPTIONS requests to an unknown URL still returns 204 No Content.""" """An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/") channel = self._make_request(b"OPTIONS", b"/foo/")
self.assertEqual(channel.result["code"], b"204") self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added # Ensure the correct CORS headers have been added
@ -262,7 +262,7 @@ class OptionsResourceTests(unittest.TestCase):
def test_known_options_request(self) -> None: def test_known_options_request(self) -> None:
"""An OPTIONS requests to an known URL still returns 204 No Content.""" """An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/") channel = self._make_request(b"OPTIONS", b"/res/")
self.assertEqual(channel.result["code"], b"204") self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added # Ensure the correct CORS headers have been added
@ -282,12 +282,12 @@ class OptionsResourceTests(unittest.TestCase):
def test_unknown_request(self) -> None: def test_unknown_request(self) -> None:
"""A non-OPTIONS request to an unknown URL should 404.""" """A non-OPTIONS request to an unknown URL should 404."""
channel = self._make_request(b"GET", b"/foo/") channel = self._make_request(b"GET", b"/foo/")
self.assertEqual(channel.result["code"], b"404") self.assertEqual(channel.code, 404)
def test_known_request(self) -> None: def test_known_request(self) -> None:
"""A non-OPTIONS request to an known URL should query the proper resource.""" """A non-OPTIONS request to an known URL should query the proper resource."""
channel = self._make_request(b"GET", b"/res/") channel = self._make_request(b"GET", b"/res/")
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"/res/") self.assertEqual(channel.result["body"], b"/res/")
@ -314,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
body = channel.result["body"] body = channel.result["body"]
self.assertEqual(body, b"response") self.assertEqual(body, b"response")
@ -334,7 +334,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
) )
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.code, 301)
headers = channel.result["headers"] headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"] location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/look/an/eagle"]) self.assertEqual(location_headers, [b"/look/an/eagle"])
@ -357,7 +357,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
) )
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.code, 304)
headers = channel.result["headers"] headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"] location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/no/over/there"]) self.assertEqual(location_headers, [b"/no/over/there"])
@ -378,7 +378,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path" self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
) )
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.code, 200)
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)

View File

@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
request_data = {"username": "kermit", "password": "monkey"} request_data = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data) channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, channel.result)
self.assertTrue(channel.json_body is not None) self.assertTrue(channel.json_body is not None)
self.assertIsInstance(channel.json_body["session"], str) self.assertIsInstance(channel.json_body["session"], str)
@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
# We don't bother checking that the response is correct - we'll leave that to # We don't bother checking that the response is correct - we'll leave that to
# other tests. We just want to make sure we're on the right path. # other tests. We just want to make sure we're on the right path.
self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.code, 401, channel.result)
# Finish the UI auth for terms # Finish the UI auth for terms
request_data = { request_data = {
@ -112,7 +112,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
# We're interested in getting a response that looks like a successful # We're interested in getting a response that looks like a successful
# registration, not so much that the details are exactly what we want. # registration, not so much that the details are exactly what we want.
self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(channel.json_body is not None) self.assertTrue(channel.json_body is not None)
self.assertIsInstance(channel.json_body["user_id"], str) self.assertIsInstance(channel.json_body["user_id"], str)