Merge pull request #932 from matrix-org/rav/register_refactor

Further registration refactoring
This commit is contained in:
David Baker 2016-07-20 11:03:33 +01:00 committed by GitHub
commit e967bc86e7
5 changed files with 57 additions and 22 deletions

View File

@ -101,6 +101,11 @@ class RegistrationHandler(BaseHandler):
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be
generated. Having this be True should be considered deprecated,
since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token
after registration.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service user_id, allowed_appservice=service
) )
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
create_profile_with_localpart=user.localpart, create_profile_with_localpart=user.localpart,
) )
defer.returnValue((user_id, token)) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response): def check_recaptcha(self, ip, private_key, challenge, response):

View File

@ -60,6 +60,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage # TODO: persistent storage
self.sessions = {} self.sessions = {}
self.enable_registration = hs.config.enable_registration self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -299,9 +300,10 @@ class RegisterRestServlet(ClientV1RestServlet):
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
(user_id, token) = yield handler.appservice_register( user_id = yield handler.appservice_register(
user_localpart, as_token user_localpart, as_token
) )
token = yield self.auth_handler.issue_access_token(user_id)
self._remove_session(session) self._remove_session(session)
defer.returnValue({ defer.returnValue({
"user_id": user_id, "user_id": user_id,

View File

@ -226,19 +226,17 @@ class RegisterRestServlet(RestServlet):
add_email = True add_email = True
access_token = yield self.auth_handler.issue_access_token( result = yield self._create_registration_details(
registered_user_id registered_user_id
) )
if add_email and result and LoginType.EMAIL_IDENTITY in result: if add_email and result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
yield self._register_email_threepid( yield self._register_email_threepid(
registered_user_id, threepid, access_token, registered_user_id, threepid, result["access_token"],
params.get("bind_email") params.get("bind_email")
) )
result = yield self._create_registration_details(registered_user_id,
access_token)
defer.returnValue((200, result)) defer.returnValue((200, result))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
@ -246,10 +244,10 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token): def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register( user_id = yield self.registration_handler.appservice_register(
username, as_token username, as_token
) )
defer.returnValue((yield self._create_registration_details(user_id, token))) defer.returnValue((yield self._create_registration_details(user_id)))
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac): def _do_shared_secret_registration(self, username, password, mac):
@ -273,10 +271,12 @@ class RegisterRestServlet(RestServlet):
403, "HMAC incorrect", 403, "HMAC incorrect",
) )
(user_id, token) = yield self.registration_handler.register( (user_id, _) = yield self.registration_handler.register(
localpart=username, password=password localpart=username, password=password, generate_token=False,
) )
defer.returnValue((yield self._create_registration_details(user_id, token)))
result = yield self._create_registration_details(user_id)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token, bind_email): def _register_email_threepid(self, user_id, threepid, token, bind_email):
@ -349,11 +349,31 @@ class RegisterRestServlet(RestServlet):
defer.returnValue() defer.returnValue()
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_registration_details(self, user_id, token): def _create_registration_details(self, user_id):
refresh_token = yield self.auth_handler.issue_refresh_token(user_id) """Complete registration of newly-registered user
Issues access_token and refresh_token, and builds the success response
body.
Args:
(str) user_id: full canonical @user:id
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
access_token = yield self.auth_handler.issue_access_token(
user_id
)
refresh_token = yield self.auth_handler.issue_refresh_token(
user_id
)
defer.returnValue({ defer.returnValue({
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"refresh_token": refresh_token, "refresh_token": refresh_token,
}) })
@ -366,7 +386,11 @@ class RegisterRestServlet(RestServlet):
generate_token=False, generate_token=False,
make_guest=True make_guest=True
) )
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"]
)
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
# so long as we don't return a refresh_token here.
defer.returnValue((200, { defer.returnValue((200, {
"user_id": user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,

View File

@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, user_id, token, password_hash, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_localpart=None, admin=False):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
user_id (str): The desired user ID to register. user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str): Optional. The password hash for this user. password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account. upgraded to a non-guest account.

View File

@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
"id": "1234" "id": "1234"
} }
self.registration_handler.appservice_register = Mock( self.registration_handler.appservice_register = Mock(
return_value=(user_id, token) return_value=user_id
) )
self.auth_handler.issue_access_token = Mock(return_value=token)
(code, result) = yield self.servlet.on_POST(self.request) (code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200) self.assertEquals(code, 200)
det_data = { det_data = {
@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result) self.assertIn("refresh_token", result)
self.auth_handler.issue_access_token.assert_called_once_with(
user_id)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False