Merge pull request #932 from matrix-org/rav/register_refactor
Further registration refactoring
This commit is contained in:
commit
e967bc86e7
|
@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
localpart : The local part of the user ID to register. If None,
|
localpart : The local part of the user ID to register. If None,
|
||||||
one will be generated.
|
one will be generated.
|
||||||
password (str) : The password to assign to this user so they can
|
password (str) : The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
|
generate_token (bool): Whether a new access token should be
|
||||||
|
generated. Having this be True should be considered deprecated,
|
||||||
|
since it offers no means of associating a device_id with the
|
||||||
|
access_token. Instead you should call auth_handler.issue_access_token
|
||||||
|
after registration.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, access_token).
|
A tuple of (user_id, access_token).
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id, allowed_appservice=service
|
user_id, allowed_appservice=service
|
||||||
)
|
)
|
||||||
|
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
|
||||||
password_hash="",
|
password_hash="",
|
||||||
appservice_id=service_id,
|
appservice_id=service_id,
|
||||||
create_profile_with_localpart=user.localpart,
|
create_profile_with_localpart=user.localpart,
|
||||||
)
|
)
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_recaptcha(self, ip, private_key, challenge, response):
|
def check_recaptcha(self, ip, private_key, challenge, response):
|
||||||
|
|
|
@ -60,6 +60,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
# TODO: persistent storage
|
# TODO: persistent storage
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.enable_registration = hs.config.enable_registration
|
self.enable_registration = hs.config.enable_registration
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
|
@ -299,9 +300,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
user_localpart = register_json["user"].encode("utf-8")
|
user_localpart = register_json["user"].encode("utf-8")
|
||||||
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
(user_id, token) = yield handler.appservice_register(
|
user_id = yield handler.appservice_register(
|
||||||
user_localpart, as_token
|
user_localpart, as_token
|
||||||
)
|
)
|
||||||
|
token = yield self.auth_handler.issue_access_token(user_id)
|
||||||
self._remove_session(session)
|
self._remove_session(session)
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
|
@ -226,19 +226,17 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
add_email = True
|
add_email = True
|
||||||
|
|
||||||
access_token = yield self.auth_handler.issue_access_token(
|
result = yield self._create_registration_details(
|
||||||
registered_user_id
|
registered_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_email and result and LoginType.EMAIL_IDENTITY in result:
|
if add_email and result and LoginType.EMAIL_IDENTITY in result:
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
yield self._register_email_threepid(
|
yield self._register_email_threepid(
|
||||||
registered_user_id, threepid, access_token,
|
registered_user_id, threepid, result["access_token"],
|
||||||
params.get("bind_email")
|
params.get("bind_email")
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self._create_registration_details(registered_user_id,
|
|
||||||
access_token)
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
def on_OPTIONS(self, _):
|
def on_OPTIONS(self, _):
|
||||||
|
@ -246,10 +244,10 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_appservice_registration(self, username, as_token):
|
def _do_appservice_registration(self, username, as_token):
|
||||||
(user_id, token) = yield self.registration_handler.appservice_register(
|
user_id = yield self.registration_handler.appservice_register(
|
||||||
username, as_token
|
username, as_token
|
||||||
)
|
)
|
||||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
defer.returnValue((yield self._create_registration_details(user_id)))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_shared_secret_registration(self, username, password, mac):
|
def _do_shared_secret_registration(self, username, password, mac):
|
||||||
|
@ -273,10 +271,12 @@ class RegisterRestServlet(RestServlet):
|
||||||
403, "HMAC incorrect",
|
403, "HMAC incorrect",
|
||||||
)
|
)
|
||||||
|
|
||||||
(user_id, token) = yield self.registration_handler.register(
|
(user_id, _) = yield self.registration_handler.register(
|
||||||
localpart=username, password=password
|
localpart=username, password=password, generate_token=False,
|
||||||
)
|
)
|
||||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
|
||||||
|
result = yield self._create_registration_details(user_id)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _register_email_threepid(self, user_id, threepid, token, bind_email):
|
def _register_email_threepid(self, user_id, threepid, token, bind_email):
|
||||||
|
@ -349,11 +349,31 @@ class RegisterRestServlet(RestServlet):
|
||||||
defer.returnValue()
|
defer.returnValue()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_registration_details(self, user_id, token):
|
def _create_registration_details(self, user_id):
|
||||||
refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
|
"""Complete registration of newly-registered user
|
||||||
|
|
||||||
|
Issues access_token and refresh_token, and builds the success response
|
||||||
|
body.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
(str) user_id: full canonical @user:id
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: (object) dictionary for response from /register
|
||||||
|
"""
|
||||||
|
|
||||||
|
access_token = yield self.auth_handler.issue_access_token(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
})
|
})
|
||||||
|
@ -366,7 +386,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
make_guest=True
|
make_guest=True
|
||||||
)
|
)
|
||||||
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
access_token = self.auth_handler.generate_access_token(
|
||||||
|
user_id, ["guest = true"]
|
||||||
|
)
|
||||||
|
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
|
||||||
|
# so long as we don't return a refresh_token here.
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
|
|
|
@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, user_id, token, password_hash,
|
def register(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None, admin=False):
|
create_profile_with_localpart=None, admin=False):
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The desired user ID to register.
|
user_id (str): The desired user ID to register.
|
||||||
token (str): The desired access token to use for this user.
|
token (str): The desired access token to use for this user. If this
|
||||||
|
is not None, the given access token is associated with the user
|
||||||
|
id.
|
||||||
password_hash (str): Optional. The password hash for this user.
|
password_hash (str): Optional. The password hash for this user.
|
||||||
was_guest (bool): Optional. Whether this is a guest account being
|
was_guest (bool): Optional. Whether this is a guest account being
|
||||||
upgraded to a non-guest account.
|
upgraded to a non-guest account.
|
||||||
|
|
|
@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"id": "1234"
|
"id": "1234"
|
||||||
}
|
}
|
||||||
self.registration_handler.appservice_register = Mock(
|
self.registration_handler.appservice_register = Mock(
|
||||||
return_value=(user_id, token)
|
return_value=user_id
|
||||||
)
|
)
|
||||||
|
self.auth_handler.issue_access_token = Mock(return_value=token)
|
||||||
|
|
||||||
(code, result) = yield self.servlet.on_POST(self.request)
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
self.assertEquals(code, 200)
|
self.assertEquals(code, 200)
|
||||||
det_data = {
|
det_data = {
|
||||||
|
@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, result)
|
self.assertDictContainsSubset(det_data, result)
|
||||||
self.assertIn("refresh_token", result)
|
self.assertIn("refresh_token", result)
|
||||||
|
self.auth_handler.issue_access_token.assert_called_once_with(
|
||||||
|
user_id)
|
||||||
|
|
||||||
def test_POST_disabled_registration(self):
|
def test_POST_disabled_registration(self):
|
||||||
self.hs.config.enable_registration = False
|
self.hs.config.enable_registration = False
|
||||||
|
|
Loading…
Reference in New Issue