diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b219b46a4b..0041646858 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -43,6 +43,7 @@ class Codes(object): EXCLUSIVE = "M_EXCLUSIVE" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_IN_USE = "M_THREEPID_IN_USE" + THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" INVALID_USERNAME = "M_INVALID_USERNAME" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 9a84873a5f..47f78eba8c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -28,8 +28,40 @@ import logging logger = logging.getLogger(__name__) +class PasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + + def __init__(self, hs): + super(PasswordRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is None: + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password") + PATTERNS = client_v2_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -89,8 +121,40 @@ class PasswordRestServlet(RestServlet): return 200, {} +class ThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + self.hs = hs + super(ThreepidRequestTokenRestServlet, self).__init__() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() @@ -157,5 +221,7 @@ class ThreepidRestServlet(RestServlet): def register_servlets(hs, http_server): + PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) + ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2088c316d1..7c6d2942dc 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -41,8 +41,40 @@ else: logger = logging.getLogger(__name__) +class RegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/email/requestToken$") + + def __init__(self, hs): + super(RegisterRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register") + PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): super(RegisterRestServlet, self).__init__() @@ -70,10 +102,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - if '/register/email/requestToken' in request.path: - ret = yield self.onEmailTokenRequest(request) - defer.returnValue(ret) - body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -305,29 +333,6 @@ class RegisterRestServlet(RestServlet): "refresh_token": refresh_token, }) - @defer.inlineCallbacks - def onEmailTokenRequest(self, request): - body = parse_json_object_from_request(request) - - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) - - if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) - - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] - ) - - if existingUid is not None: - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) - @defer.inlineCallbacks def _do_guest_registration(self): if not self.hs.config.allow_guest_access: @@ -345,4 +350,5 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): + RegisterRequestTokenRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)