diff --git a/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewEvents.kt b/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewEvents.kt index 5d6e7005c4..bf53a72cc3 100644 --- a/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewEvents.kt +++ b/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewEvents.kt @@ -20,7 +20,7 @@ package im.vector.app.features.onboarding import im.vector.app.core.platform.VectorViewEvents import im.vector.app.features.login.ServerType import im.vector.app.features.login.SignMode -import org.matrix.android.sdk.api.auth.registration.FlowResult +import org.matrix.android.sdk.api.auth.registration.Stage /** * Transient events for Login. @@ -30,7 +30,9 @@ sealed class OnboardingViewEvents : VectorViewEvents { data class Failure(val throwable: Throwable) : OnboardingViewEvents() data class DeeplinkAuthenticationFailure(val retryAction: OnboardingAction) : OnboardingViewEvents() - data class RegistrationFlowResult(val flowResult: FlowResult, val isRegistrationStarted: Boolean) : OnboardingViewEvents() + object DisplayRegistrationFallback : OnboardingViewEvents() + data class DisplayRegistrationStage(val stage: Stage) : OnboardingViewEvents() + object DisplayStartRegistration : OnboardingViewEvents() object OutdatedHomeserver : OnboardingViewEvents() // Navigation event diff --git a/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt b/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt index 61877a5f47..19f6d226ca 100644 --- a/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt +++ b/vector/src/main/java/im/vector/app/features/onboarding/OnboardingViewModel.kt @@ -47,7 +47,6 @@ import im.vector.app.features.login.ServerType import im.vector.app.features.login.SignMode import im.vector.app.features.onboarding.OnboardingAction.AuthenticateAction import im.vector.app.features.onboarding.StartAuthenticationFlowUseCase.StartAuthenticationResult -import im.vector.app.features.onboarding.ftueauth.MatrixOrgRegistrationStagesComparator import kotlinx.coroutines.Job import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.launch @@ -56,9 +55,7 @@ import org.matrix.android.sdk.api.auth.HomeServerHistoryService import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig import org.matrix.android.sdk.api.auth.data.SsoIdentityProvider import org.matrix.android.sdk.api.auth.login.LoginWizard -import org.matrix.android.sdk.api.auth.registration.FlowResult import org.matrix.android.sdk.api.auth.registration.RegistrationWizard -import org.matrix.android.sdk.api.auth.registration.Stage import org.matrix.android.sdk.api.failure.isHomeserverUnavailable import org.matrix.android.sdk.api.session.Session import timber.log.Timber @@ -80,11 +77,11 @@ class OnboardingViewModel @AssistedInject constructor( private val vectorFeatures: VectorFeatures, private val analyticsTracker: AnalyticsTracker, private val uriFilenameResolver: UriFilenameResolver, - private val registrationActionHandler: RegistrationActionHandler, private val directLoginUseCase: DirectLoginUseCase, private val startAuthenticationFlowUseCase: StartAuthenticationFlowUseCase, private val vectorOverrides: VectorOverrides, - private val buildMeta: BuildMeta + private val registrationActionHandler: RegistrationActionHandler, + private val buildMeta: BuildMeta, ) : VectorViewModel(initialState) { @AssistedFactory @@ -150,18 +147,18 @@ class OnboardingViewModel @AssistedInject constructor( is OnboardingAction.WebLoginSuccess -> handleWebLoginSuccess(action) is OnboardingAction.ResetPassword -> handleResetPassword(action) is OnboardingAction.ResetPasswordMailConfirmed -> handleResetPasswordMailConfirmed() - is OnboardingAction.PostRegisterAction -> handleRegisterAction(action.registerAction, ::emitFlowResultViewEvent) - is OnboardingAction.ResetAction -> handleResetAction(action) - is OnboardingAction.UserAcceptCertificate -> handleUserAcceptCertificate(action) - OnboardingAction.ClearHomeServerHistory -> handleClearHomeServerHistory() - is OnboardingAction.UpdateDisplayName -> updateDisplayName(action.displayName) - OnboardingAction.UpdateDisplayNameSkipped -> handleDisplayNameStepComplete() - OnboardingAction.UpdateProfilePictureSkipped -> completePersonalization() - OnboardingAction.PersonalizeProfile -> handlePersonalizeProfile() - is OnboardingAction.ProfilePictureSelected -> handleProfilePictureSelected(action) - OnboardingAction.SaveSelectedProfilePicture -> updateProfilePicture() - is OnboardingAction.PostViewEvent -> _viewEvents.post(action.viewEvent) - OnboardingAction.StopEmailValidationCheck -> cancelWaitForEmailValidation() + is OnboardingAction.PostRegisterAction -> handleRegisterAction(action.registerAction) + is OnboardingAction.ResetAction -> handleResetAction(action) + is OnboardingAction.UserAcceptCertificate -> handleUserAcceptCertificate(action) + OnboardingAction.ClearHomeServerHistory -> handleClearHomeServerHistory() + is OnboardingAction.UpdateDisplayName -> updateDisplayName(action.displayName) + OnboardingAction.UpdateDisplayNameSkipped -> handleDisplayNameStepComplete() + OnboardingAction.UpdateProfilePictureSkipped -> completePersonalization() + OnboardingAction.PersonalizeProfile -> handlePersonalizeProfile() + is OnboardingAction.ProfilePictureSelected -> handleProfilePictureSelected(action) + OnboardingAction.SaveSelectedProfilePicture -> updateProfilePicture() + is OnboardingAction.PostViewEvent -> _viewEvents.post(action.viewEvent) + OnboardingAction.StopEmailValidationCheck -> cancelWaitForEmailValidation() } } @@ -259,12 +256,12 @@ class OnboardingViewModel @AssistedInject constructor( } } - private fun handleRegisterAction(action: RegisterAction, onNextRegistrationStepAction: (FlowResult) -> Unit) { + private fun handleRegisterAction(action: RegisterAction) { val job = viewModelScope.launch { if (action.hasLoadingState()) { setState { copy(isLoading = true) } } - internalRegisterAction(action, onNextRegistrationStepAction) + internalRegisterAction(action) setState { copy(isLoading = false) } } @@ -275,23 +272,28 @@ class OnboardingViewModel @AssistedInject constructor( } } - private suspend fun internalRegisterAction(action: RegisterAction, onNextRegistrationStepAction: (FlowResult) -> Unit) { - runCatching { registrationActionHandler.handleRegisterAction(registrationWizard, action) } + private suspend fun internalRegisterAction(action: RegisterAction, overrideNextStage: (() -> Unit)? = null) { + runCatching { registrationActionHandler.processAction(awaitState().selectedHomeserver, action) } .fold( onSuccess = { - when { - action.ignoresResult() -> { + when (it) { + RegistrationActionHandler.Result.Ignored -> { // do nothing } - else -> when (it) { - is RegistrationResult.Complete -> onSessionCreated( - it.session, - authenticationDescription = awaitState().selectedAuthenticationState.description - ?: AuthenticationDescription.Register(AuthenticationDescription.AuthenticationType.Other) - ) - is RegistrationResult.NextStep -> onFlowResponse(it.flowResult, onNextRegistrationStepAction) - is RegistrationResult.SendEmailSuccess -> _viewEvents.post(OnboardingViewEvents.OnSendEmailSuccess(it.email)) - is RegistrationResult.Error -> _viewEvents.post(OnboardingViewEvents.Failure(it.cause)) + is RegistrationActionHandler.Result.NextStage -> { + overrideNextStage?.invoke() ?: _viewEvents.post(OnboardingViewEvents.DisplayRegistrationStage(it.stage)) + } + is RegistrationActionHandler.Result.Success -> onSessionCreated( + it.session, + authenticationDescription = awaitState().selectedAuthenticationState.description + ?: AuthenticationDescription.Register(AuthenticationDescription.AuthenticationType.Other) + ) + RegistrationActionHandler.Result.StartRegistration -> _viewEvents.post(OnboardingViewEvents.DisplayStartRegistration) + RegistrationActionHandler.Result.UnsupportedStage -> _viewEvents.post(OnboardingViewEvents.DisplayRegistrationFallback) + is RegistrationActionHandler.Result.SendEmailSuccess -> _viewEvents.post(OnboardingViewEvents.OnSendEmailSuccess(it.email)) + is RegistrationActionHandler.Result.Error -> _viewEvents.post(OnboardingViewEvents.Failure(it.cause)) + RegistrationActionHandler.Result.MissingNextStage -> { + _viewEvents.post(OnboardingViewEvents.Failure(IllegalStateException("No next registration stage found"))) } } }, @@ -303,18 +305,6 @@ class OnboardingViewModel @AssistedInject constructor( ) } - private fun emitFlowResultViewEvent(flowResult: FlowResult) { - withState { state -> - val orderedResult = when { - state.hasSelectedMatrixOrg() && vectorFeatures.isOnboardingCombinedRegisterEnabled() -> flowResult.copy( - missingStages = flowResult.missingStages.sortedWith(MatrixOrgRegistrationStagesComparator()) - ) - else -> flowResult - } - _viewEvents.post(OnboardingViewEvents.RegistrationFlowResult(orderedResult, isRegistrationStarted)) - } - } - private fun OnboardingViewState.hasSelectedMatrixOrg() = selectedHomeserver.userFacingUrl == matrixOrgUrl private fun handleRegisterWith(action: AuthenticateAction.Register) { @@ -328,8 +318,7 @@ class OnboardingViewModel @AssistedInject constructor( action.username, action.password, action.initialDeviceName - ), - ::emitFlowResultViewEvent + ) ) } @@ -382,8 +371,8 @@ class OnboardingViewModel @AssistedInject constructor( private fun handleUpdateSignMode(action: OnboardingAction.UpdateSignMode) { updateSignMode(action.signMode) when (action.signMode) { - SignMode.SignUp -> handleRegisterAction(RegisterAction.StartRegistration, ::emitFlowResultViewEvent) - SignMode.SignIn -> startAuthenticationFlow() + SignMode.SignUp -> handleRegisterAction(RegisterAction.StartRegistration) + SignMode.SignIn -> startAuthenticationFlow() SignMode.SignInWithMatrixId -> _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignInWithMatrixId)) SignMode.Unknown -> Unit } @@ -530,19 +519,6 @@ class OnboardingViewModel @AssistedInject constructor( _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignIn)) } - private suspend fun onFlowResponse(flowResult: FlowResult, onNextRegistrationStepAction: (FlowResult) -> Unit) { - // If dummy stage is mandatory, and password is already sent, do the dummy stage now - if (isRegistrationStarted && flowResult.missingStages.any { it is Stage.Dummy && it.mandatory }) { - handleRegisterDummy(onNextRegistrationStepAction) - } else { - onNextRegistrationStepAction(flowResult) - } - } - - private suspend fun handleRegisterDummy(onNextRegistrationStepAction: (FlowResult) -> Unit) { - internalRegisterAction(RegisterAction.RegisterDummy, onNextRegistrationStepAction) - } - private suspend fun onSessionCreated(session: Session, authenticationDescription: AuthenticationDescription) { val state = awaitState() state.useCase?.let { useCase -> @@ -684,7 +660,7 @@ class OnboardingViewModel @AssistedInject constructor( } OnboardingFlow.SignUp -> { updateSignMode(SignMode.SignUp) - internalRegisterAction(RegisterAction.StartRegistration, ::emitFlowResultViewEvent) + internalRegisterAction(RegisterAction.StartRegistration) } OnboardingFlow.SignInSignUp, null -> { diff --git a/vector/src/main/java/im/vector/app/features/onboarding/RegistrationActionHandler.kt b/vector/src/main/java/im/vector/app/features/onboarding/RegistrationActionHandler.kt index 3c3ac95cf2..9520413cd8 100644 --- a/vector/src/main/java/im/vector/app/features/onboarding/RegistrationActionHandler.kt +++ b/vector/src/main/java/im/vector/app/features/onboarding/RegistrationActionHandler.kt @@ -16,105 +16,91 @@ package im.vector.app.features.onboarding +import im.vector.app.R +import im.vector.app.core.resources.StringProvider +import im.vector.app.core.utils.ensureTrailingSlash +import im.vector.app.features.VectorFeatures +import im.vector.app.features.VectorOverrides +import im.vector.app.features.login.isSupported +import im.vector.app.features.onboarding.ftueauth.MatrixOrgRegistrationStagesComparator +import kotlinx.coroutines.flow.first +import org.matrix.android.sdk.api.auth.AuthenticationService import org.matrix.android.sdk.api.auth.registration.FlowResult -import org.matrix.android.sdk.api.auth.registration.RegisterThreePid -import org.matrix.android.sdk.api.auth.registration.RegistrationResult.FlowResponse -import org.matrix.android.sdk.api.auth.registration.RegistrationResult.Success -import org.matrix.android.sdk.api.auth.registration.RegistrationWizard -import org.matrix.android.sdk.api.failure.is401 +import org.matrix.android.sdk.api.auth.registration.Stage import org.matrix.android.sdk.api.session.Session import javax.inject.Inject -import org.matrix.android.sdk.api.auth.registration.RegistrationResult as MatrixRegistrationResult -class RegistrationActionHandler @Inject constructor() { +class RegistrationActionHandler @Inject constructor( + private val registrationWizardActionDelegate: RegistrationWizardActionDelegate, + private val authenticationService: AuthenticationService, + private val vectorOverrides: VectorOverrides, + private val vectorFeatures: VectorFeatures, + stringProvider: StringProvider +) { - suspend fun handleRegisterAction(registrationWizard: RegistrationWizard, action: RegisterAction): RegistrationResult { - return when (action) { - RegisterAction.StartRegistration -> resultOf { registrationWizard.getRegistrationFlow() } - is RegisterAction.CaptchaDone -> resultOf { registrationWizard.performReCaptcha(action.captchaResponse) } - is RegisterAction.AcceptTerms -> resultOf { registrationWizard.acceptTerms() } - is RegisterAction.RegisterDummy -> resultOf { registrationWizard.dummy() } - is RegisterAction.AddThreePid -> handleAddThreePid(registrationWizard, action) - is RegisterAction.SendAgainThreePid -> resultOf { registrationWizard.sendAgainThreePid() } - is RegisterAction.ValidateThreePid -> resultOf { registrationWizard.handleValidateThreePid(action.code) } - is RegisterAction.CheckIfEmailHasBeenValidated -> handleCheckIfEmailIsValidated(registrationWizard, action.delayMillis) - is RegisterAction.CreateAccount -> resultOf { - registrationWizard.createAccount( - action.username, - action.password, - action.initialDeviceName - ) + private val matrixOrgUrl = stringProvider.getString(R.string.matrix_org_server_url).ensureTrailingSlash() + + suspend fun processAction(state: SelectedHomeserverState, action: RegisterAction): Result { + val result = registrationWizardActionDelegate.executeAction(action) + return when { + action.ignoresResult() -> Result.Ignored + else -> when (result) { + is RegistrationResult.Complete -> Result.Success(result.session) + is RegistrationResult.NextStep -> processFlowResult(result, state) + is RegistrationResult.SendEmailSuccess -> Result.SendEmailSuccess(result.email) + is RegistrationResult.Error -> Result.Error(result.cause) } } } - private suspend fun handleAddThreePid(wizard: RegistrationWizard, action: RegisterAction.AddThreePid): RegistrationResult { - return runCatching { wizard.addThreePid(action.threePid) }.fold( - onSuccess = { it.toRegistrationResult() }, - onFailure = { - when { - action.threePid is RegisterThreePid.Email && it.is401() -> RegistrationResult.SendEmailSuccess(action.threePid.email) - else -> RegistrationResult.Error(it) - } - } - ) + private suspend fun processFlowResult(result: RegistrationResult.NextStep, state: SelectedHomeserverState): Result { + // If dummy stage is mandatory, and password is already sent, do the dummy stage now + return if (authenticationService.isRegistrationStarted() && result.flowResult.missingStages.hasMandatoryDummy()) { + processAction(state, RegisterAction.RegisterDummy) + } else { + handleNextStep(state, result.flowResult) + } } - private tailrec suspend fun handleCheckIfEmailIsValidated(registrationWizard: RegistrationWizard, delayMillis: Long): RegistrationResult { - return runCatching { registrationWizard.checkIfEmailHasBeenValidated(delayMillis) }.fold( - onSuccess = { it.toRegistrationResult() }, - onFailure = { - when { - it.is401() -> null // recursively continue to check with a delay - else -> RegistrationResult.Error(it) - } - } - ) ?: handleCheckIfEmailIsValidated(registrationWizard, 10_000) + private suspend fun handleNextStep(state: SelectedHomeserverState, flowResult: FlowResult): Result { + return when { + flowResult.registrationShouldFallback() -> Result.UnsupportedStage + authenticationService.isRegistrationStarted() -> findNextStage(state, flowResult) + else -> Result.StartRegistration + } + } + + private fun findNextStage(state: SelectedHomeserverState, flowResult: FlowResult): Result { + val orderedResult = when { + state.hasSelectedMatrixOrg() && vectorFeatures.isOnboardingCombinedRegisterEnabled() -> flowResult.copy( + missingStages = flowResult.missingStages.sortedWith(MatrixOrgRegistrationStagesComparator()) + ) + else -> flowResult + } + return orderedResult.findNextRegistrationStage() + ?.let { Result.NextStage(it) } + ?: Result.MissingNextStage + } + + private fun FlowResult.findNextRegistrationStage() = missingStages.firstMandatoryOrNull() ?: missingStages.ignoreDummy().firstOptionalOrNull() + + private suspend fun FlowResult.registrationShouldFallback() = vectorOverrides.forceLoginFallback.first() || missingStages.any { !it.isSupported() } + + private fun SelectedHomeserverState.hasSelectedMatrixOrg() = userFacingUrl == matrixOrgUrl + + sealed interface Result { + data class Success(val session: Session) : Result + data class NextStage(val stage: Stage) : Result + data class Error(val cause: Throwable) : Result + data class SendEmailSuccess(val email: String) : Result + object MissingNextStage : Result + object StartRegistration : Result + object UnsupportedStage : Result + object Ignored : Result } } -private inline fun resultOf(block: () -> MatrixRegistrationResult): RegistrationResult { - return runCatching { block() }.fold( - onSuccess = { it.toRegistrationResult() }, - onFailure = { RegistrationResult.Error(it) } - ) -} - -private fun MatrixRegistrationResult.toRegistrationResult() = when (this) { - is FlowResponse -> RegistrationResult.NextStep(flowResult) - is Success -> RegistrationResult.Complete(session) -} - -sealed interface RegistrationResult { - data class Error(val cause: Throwable) : RegistrationResult - data class Complete(val session: Session) : RegistrationResult - data class NextStep(val flowResult: FlowResult) : RegistrationResult - data class SendEmailSuccess(val email: String) : RegistrationResult -} - -sealed interface RegisterAction { - object StartRegistration : RegisterAction - data class CreateAccount(val username: String, val password: String, val initialDeviceName: String) : RegisterAction - - data class AddThreePid(val threePid: RegisterThreePid) : RegisterAction - object SendAgainThreePid : RegisterAction - - // TODO Confirm Email (from link in the email, open in the phone, intercepted by the app) - data class ValidateThreePid(val code: String) : RegisterAction - - data class CheckIfEmailHasBeenValidated(val delayMillis: Long) : RegisterAction - - data class CaptchaDone(val captchaResponse: String) : RegisterAction - object AcceptTerms : RegisterAction - object RegisterDummy : RegisterAction -} - -fun RegisterAction.ignoresResult() = when (this) { - is RegisterAction.SendAgainThreePid -> true - else -> false -} - -fun RegisterAction.hasLoadingState() = when (this) { - is RegisterAction.CheckIfEmailHasBeenValidated -> false - else -> true -} +private fun List.firstMandatoryOrNull() = firstOrNull { it.mandatory } +private fun List.firstOptionalOrNull() = firstOrNull { !it.mandatory } +private fun List.ignoreDummy() = filter { it !is Stage.Dummy } +private fun List.hasMandatoryDummy() = any { it is Stage.Dummy && it.mandatory } diff --git a/vector/src/main/java/im/vector/app/features/onboarding/RegistrationWizardActionDelegate.kt b/vector/src/main/java/im/vector/app/features/onboarding/RegistrationWizardActionDelegate.kt new file mode 100644 index 0000000000..5ce8bb857b --- /dev/null +++ b/vector/src/main/java/im/vector/app/features/onboarding/RegistrationWizardActionDelegate.kt @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.vector.app.features.onboarding + +import org.matrix.android.sdk.api.auth.AuthenticationService +import org.matrix.android.sdk.api.auth.registration.FlowResult +import org.matrix.android.sdk.api.auth.registration.RegisterThreePid +import org.matrix.android.sdk.api.auth.registration.RegistrationResult.FlowResponse +import org.matrix.android.sdk.api.auth.registration.RegistrationResult.Success +import org.matrix.android.sdk.api.auth.registration.RegistrationWizard +import org.matrix.android.sdk.api.failure.is401 +import org.matrix.android.sdk.api.session.Session +import javax.inject.Inject +import org.matrix.android.sdk.api.auth.registration.RegistrationResult as MatrixRegistrationResult + +class RegistrationWizardActionDelegate @Inject constructor( + private val authenticationService: AuthenticationService +) { + + private val registrationWizard: RegistrationWizard + get() = authenticationService.getRegistrationWizard() + + suspend fun executeAction(action: RegisterAction): RegistrationResult { + return when (action) { + RegisterAction.StartRegistration -> resultOf { registrationWizard.getRegistrationFlow() } + is RegisterAction.CaptchaDone -> resultOf { registrationWizard.performReCaptcha(action.captchaResponse) } + is RegisterAction.AcceptTerms -> resultOf { registrationWizard.acceptTerms() } + is RegisterAction.RegisterDummy -> resultOf { registrationWizard.dummy() } + is RegisterAction.AddThreePid -> handleAddThreePid(registrationWizard, action) + is RegisterAction.SendAgainThreePid -> resultOf { registrationWizard.sendAgainThreePid() } + is RegisterAction.ValidateThreePid -> resultOf { registrationWizard.handleValidateThreePid(action.code) } + is RegisterAction.CheckIfEmailHasBeenValidated -> handleCheckIfEmailIsValidated(registrationWizard, action.delayMillis) + is RegisterAction.CreateAccount -> resultOf { + registrationWizard.createAccount( + action.username, + action.password, + action.initialDeviceName + ) + } + } + } + + private suspend fun handleAddThreePid(wizard: RegistrationWizard, action: RegisterAction.AddThreePid): RegistrationResult { + return runCatching { wizard.addThreePid(action.threePid) }.fold( + onSuccess = { it.toRegistrationResult() }, + onFailure = { + when { + action.threePid is RegisterThreePid.Email && it.is401() -> RegistrationResult.SendEmailSuccess(action.threePid.email) + else -> RegistrationResult.Error(it) + } + } + ) + } + + private tailrec suspend fun handleCheckIfEmailIsValidated(registrationWizard: RegistrationWizard, delayMillis: Long): RegistrationResult { + return runCatching { registrationWizard.checkIfEmailHasBeenValidated(delayMillis) }.fold( + onSuccess = { it.toRegistrationResult() }, + onFailure = { + when { + it.is401() -> null // recursively continue to check with a delay + else -> RegistrationResult.Error(it) + } + } + ) ?: handleCheckIfEmailIsValidated(registrationWizard, 10_000) + } +} + +private inline fun resultOf(block: () -> MatrixRegistrationResult): RegistrationResult { + return runCatching { block() }.fold( + onSuccess = { it.toRegistrationResult() }, + onFailure = { RegistrationResult.Error(it) } + ) +} + +private fun MatrixRegistrationResult.toRegistrationResult() = when (this) { + is FlowResponse -> RegistrationResult.NextStep(flowResult) + is Success -> RegistrationResult.Complete(session) +} + +sealed interface RegistrationResult { + data class Error(val cause: Throwable) : RegistrationResult + data class Complete(val session: Session) : RegistrationResult + data class NextStep(val flowResult: FlowResult) : RegistrationResult + data class SendEmailSuccess(val email: String) : RegistrationResult +} + +sealed interface RegisterAction { + object StartRegistration : RegisterAction + data class CreateAccount(val username: String, val password: String, val initialDeviceName: String) : RegisterAction + + data class AddThreePid(val threePid: RegisterThreePid) : RegisterAction + object SendAgainThreePid : RegisterAction + data class ValidateThreePid(val code: String) : RegisterAction + data class CheckIfEmailHasBeenValidated(val delayMillis: Long) : RegisterAction + + data class CaptchaDone(val captchaResponse: String) : RegisterAction + object AcceptTerms : RegisterAction + object RegisterDummy : RegisterAction +} + +fun RegisterAction.ignoresResult() = when (this) { + is RegisterAction.SendAgainThreePid -> true + else -> false +} + +fun RegisterAction.hasLoadingState() = when (this) { + is RegisterAction.CheckIfEmailHasBeenValidated -> false + else -> true +} diff --git a/vector/src/main/java/im/vector/app/features/onboarding/ftueauth/FtueAuthVariant.kt b/vector/src/main/java/im/vector/app/features/onboarding/ftueauth/FtueAuthVariant.kt index f8ad700b40..89e28740a4 100644 --- a/vector/src/main/java/im/vector/app/features/onboarding/ftueauth/FtueAuthVariant.kt +++ b/vector/src/main/java/im/vector/app/features/onboarding/ftueauth/FtueAuthVariant.kt @@ -44,7 +44,6 @@ import im.vector.app.features.login.LoginMode import im.vector.app.features.login.ServerType import im.vector.app.features.login.SignMode import im.vector.app.features.login.TextInputFormFragmentMode -import im.vector.app.features.login.isSupported import im.vector.app.features.onboarding.OnboardingAction import im.vector.app.features.onboarding.OnboardingActivity import im.vector.app.features.onboarding.OnboardingVariant @@ -129,10 +128,7 @@ class FtueAuthVariant( private fun handleOnboardingViewEvents(viewEvents: OnboardingViewEvents) { when (viewEvents) { - is OnboardingViewEvents.RegistrationFlowResult -> { - onRegistrationFlow(viewEvents) - } - is OnboardingViewEvents.OutdatedHomeserver -> { + is OnboardingViewEvents.OutdatedHomeserver -> { MaterialAlertDialogBuilder(activity) .setTitle(R.string.login_error_outdated_homeserver_title) .setMessage(R.string.login_error_outdated_homeserver_warning_content) @@ -227,9 +223,15 @@ class FtueAuthVariant( option = commonOption ) } - OnboardingViewEvents.OnHomeserverEdited -> activity.popBackstack() - OnboardingViewEvents.OpenCombinedLogin -> onStartCombinedLogin() - is OnboardingViewEvents.DeeplinkAuthenticationFailure -> onDeeplinkedHomeserverUnavailable(viewEvents) + OnboardingViewEvents.OnHomeserverEdited -> activity.popBackstack() + OnboardingViewEvents.OpenCombinedLogin -> onStartCombinedLogin() + is OnboardingViewEvents.DeeplinkAuthenticationFailure -> onDeeplinkedHomeserverUnavailable(viewEvents) + OnboardingViewEvents.DisplayRegistrationFallback -> displayFallbackWebDialog() + is OnboardingViewEvents.DisplayRegistrationStage -> doStage(viewEvents.stage) + OnboardingViewEvents.DisplayStartRegistration -> when { + vectorFeatures.isOnboardingCombinedRegisterEnabled() -> openStartCombinedRegister() + else -> openAuthLoginFragmentWithTag(FRAGMENT_REGISTRATION_STAGE_TAG) + } } } @@ -253,25 +255,10 @@ class FtueAuthVariant( addRegistrationStageFragmentToBackstack(FtueAuthCombinedLoginFragment::class.java) } - private fun onRegistrationFlow(viewEvents: OnboardingViewEvents.RegistrationFlowResult) { - when { - registrationShouldFallback(viewEvents) -> displayFallbackWebDialog() - viewEvents.isRegistrationStarted -> handleRegistrationNavigation(viewEvents.flowResult.missingStages) - vectorFeatures.isOnboardingCombinedRegisterEnabled() -> openStartCombinedRegister() - else -> openAuthLoginFragmentWithTag(FRAGMENT_REGISTRATION_STAGE_TAG) - } - } - private fun openStartCombinedRegister() { addRegistrationStageFragmentToBackstack(FtueAuthCombinedRegisterFragment::class.java) } - private fun registrationShouldFallback(registrationFlowResult: OnboardingViewEvents.RegistrationFlowResult) = - isForceLoginFallbackEnabled || registrationFlowResult.containsUnsupportedRegistrationFlow() - - private fun OnboardingViewEvents.RegistrationFlowResult.containsUnsupportedRegistrationFlow() = - flowResult.missingStages.any { !it.isSupported() } - private fun displayFallbackWebDialog() { MaterialAlertDialogBuilder(activity) .setTitle(R.string.app_name) @@ -381,23 +368,6 @@ class FtueAuthVariant( ?.let { onboardingViewModel.handle(OnboardingAction.LoginWithToken(it)) } } - private fun handleRegistrationNavigation(remainingStages: List) { - // Complete all mandatory stages first - val mandatoryStage = remainingStages.firstOrNull { it.mandatory } - - if (mandatoryStage != null) { - doStage(mandatoryStage) - } else { - // Consider optional stages - val optionalStage = remainingStages.firstOrNull { !it.mandatory && it !is Stage.Dummy } - if (optionalStage == null) { - // Should not happen... - } else { - doStage(optionalStage) - } - } - } - private fun doStage(stage: Stage) { // Ensure there is no fragment for registration stage in the backstack supportFragmentManager.popBackStack(FRAGMENT_REGISTRATION_STAGE_TAG, FragmentManager.POP_BACK_STACK_INCLUSIVE) diff --git a/vector/src/test/java/im/vector/app/features/onboarding/OnboardingViewModelTest.kt b/vector/src/test/java/im/vector/app/features/onboarding/OnboardingViewModelTest.kt index 77539da232..ce32dfeb3d 100644 --- a/vector/src/test/java/im/vector/app/features/onboarding/OnboardingViewModelTest.kt +++ b/vector/src/test/java/im/vector/app/features/onboarding/OnboardingViewModelTest.kt @@ -31,9 +31,8 @@ import im.vector.app.test.fakes.FakeContext import im.vector.app.test.fakes.FakeDirectLoginUseCase import im.vector.app.test.fakes.FakeHomeServerConnectionConfigFactory import im.vector.app.test.fakes.FakeHomeServerHistoryService +import im.vector.app.test.fakes.FakeRegistrationActionHandler import im.vector.app.test.fakes.FakeLoginWizard -import im.vector.app.test.fakes.FakeRegisterActionHandler -import im.vector.app.test.fakes.FakeRegistrationWizard import im.vector.app.test.fakes.FakeSession import im.vector.app.test.fakes.FakeStartAuthenticationFlowUseCase import im.vector.app.test.fakes.FakeStringProvider @@ -50,7 +49,6 @@ import org.junit.Before import org.junit.Rule import org.junit.Test import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig -import org.matrix.android.sdk.api.auth.registration.FlowResult import org.matrix.android.sdk.api.auth.registration.Stage import org.matrix.android.sdk.api.session.Session import org.matrix.android.sdk.api.session.homeserver.HomeServerCapabilities @@ -62,8 +60,7 @@ private val A_LOADABLE_REGISTER_ACTION = RegisterAction.StartRegistration private val A_NON_LOADABLE_REGISTER_ACTION = RegisterAction.CheckIfEmailHasBeenValidated(delayMillis = -1L) private val A_RESULT_IGNORED_REGISTER_ACTION = RegisterAction.SendAgainThreePid private val A_HOMESERVER_CAPABILITIES = aHomeServerCapabilities(canChangeDisplayName = true, canChangeAvatar = true) -private val AN_IGNORED_FLOW_RESULT = FlowResult(missingStages = emptyList(), completedStages = emptyList()) -private val ANY_CONTINUING_REGISTRATION_RESULT = RegistrationResult.NextStep(AN_IGNORED_FLOW_RESULT) +private val ANY_CONTINUING_REGISTRATION_RESULT = RegistrationActionHandler.Result.NextStage(Stage.Dummy(mandatory = true)) private val A_DIRECT_LOGIN = OnboardingAction.AuthenticateAction.LoginDirect("@a-user:id.org", "a-password", "a-device-name") private const val A_HOMESERVER_URL = "https://edited-homeserver.org" private val A_HOMESERVER_CONFIG = HomeServerConnectionConfig(FakeUri().instance) @@ -82,7 +79,7 @@ class OnboardingViewModelTest { private val fakeUriFilenameResolver = FakeUriFilenameResolver() private val fakeActiveSessionHolder = FakeActiveSessionHolder(fakeSession) private val fakeAuthenticationService = FakeAuthenticationService() - private val fakeRegisterActionHandler = FakeRegisterActionHandler() + private val fakeRegistrationActionHandler = FakeRegistrationActionHandler() private val fakeDirectLoginUseCase = FakeDirectLoginUseCase() private val fakeVectorFeatures = FakeVectorFeatures() private val fakeHomeServerConnectionConfigFactory = FakeHomeServerConnectionConfigFactory() @@ -199,7 +196,7 @@ class OnboardingViewModelTest { { copy(isLoading = true) }, { copy(isLoading = false) } ) - .assertEvents(OnboardingViewEvents.RegistrationFlowResult(ANY_CONTINUING_REGISTRATION_RESULT.flowResult, isRegistrationStarted = true)) + .assertEvents(OnboardingViewEvents.DisplayRegistrationStage(ANY_CONTINUING_REGISTRATION_RESULT.stage)) .finish() } @@ -216,7 +213,7 @@ class OnboardingViewModelTest { { copy(isLoading = true) }, { copy(isLoading = false) } ) - .assertEvents(OnboardingViewEvents.RegistrationFlowResult(ANY_CONTINUING_REGISTRATION_RESULT.flowResult, isRegistrationStarted = true)) + .assertEvents(OnboardingViewEvents.DisplayRegistrationStage(ANY_CONTINUING_REGISTRATION_RESULT.stage)) .finish() } @@ -229,14 +226,14 @@ class OnboardingViewModelTest { test .assertState(initialState) - .assertEvents(OnboardingViewEvents.RegistrationFlowResult(ANY_CONTINUING_REGISTRATION_RESULT.flowResult, isRegistrationStarted = true)) + .assertEvents(OnboardingViewEvents.DisplayRegistrationStage(ANY_CONTINUING_REGISTRATION_RESULT.stage)) .finish() } @Test fun `given register action ignores result, when handling action, then does nothing on success`() = runTest { val test = viewModel.test() - givenRegistrationResultFor(A_RESULT_IGNORED_REGISTER_ACTION, RegistrationResult.NextStep(AN_IGNORED_FLOW_RESULT)) + givenRegistrationResultFor(A_RESULT_IGNORED_REGISTER_ACTION, RegistrationActionHandler.Result.Ignored) viewModel.handle(OnboardingAction.PostRegisterAction(A_RESULT_IGNORED_REGISTER_ACTION)) @@ -276,7 +273,7 @@ class OnboardingViewModelTest { viewModelWith(initialState.copy(onboardingFlow = OnboardingFlow.SignUp)) fakeHomeServerConnectionConfigFactory.givenConfigFor(A_HOMESERVER_URL, A_HOMESERVER_CONFIG) fakeStartAuthenticationFlowUseCase.givenResult(A_HOMESERVER_CONFIG, StartAuthenticationResult(isHomeserverOutdated = false, SELECTED_HOMESERVER_STATE)) - givenRegistrationResultFor(RegisterAction.StartRegistration, RegistrationResult.NextStep(AN_IGNORED_FLOW_RESULT)) + givenRegistrationResultFor(RegisterAction.StartRegistration, ANY_CONTINUING_REGISTRATION_RESULT) fakeHomeServerHistoryService.expectUrlToBeAdded(A_HOMESERVER_CONFIG.homeServerUri.toString()) val test = viewModel.test() @@ -318,11 +315,11 @@ class OnboardingViewModelTest { @Test fun `given personalisation enabled, when registering account, then updates state and emits account created event`() = runTest { fakeVectorFeatures.givenPersonalisationEnabled() - givenRegistrationResultFor(A_LOADABLE_REGISTER_ACTION, RegistrationResult.Complete(fakeSession)) givenSuccessfullyCreatesAccount(A_HOMESERVER_CAPABILITIES) + givenRegistrationResultFor(RegisterAction.StartRegistration, RegistrationActionHandler.Result.Success(fakeSession)) val test = viewModel.test() - viewModel.handle(OnboardingAction.PostRegisterAction(A_LOADABLE_REGISTER_ACTION)) + viewModel.handle(OnboardingAction.PostRegisterAction(RegisterAction.StartRegistration)) test .assertStatesChanges( @@ -334,26 +331,6 @@ class OnboardingViewModelTest { .finish() } - @Test - fun `given personalisation enabled and registration has started and has dummy step to do, when handling action, then ignores other steps and does dummy`() { - runTest { - fakeVectorFeatures.givenPersonalisationEnabled() - givenSuccessfulRegistrationForStartAndDummySteps(missingStages = listOf(Stage.Dummy(mandatory = true))) - val test = viewModel.test() - - viewModel.handle(OnboardingAction.PostRegisterAction(A_LOADABLE_REGISTER_ACTION)) - - test - .assertStatesChanges( - initialState, - { copy(isLoading = true) }, - { copy(isLoading = false, personalizationState = A_HOMESERVER_CAPABILITIES.toPersonalisationState()) } - ) - .assertEvents(OnboardingViewEvents.OnAccountCreated) - .finish() - } - } - @Test fun `given changing profile avatar is supported, when updating display name, then updates upstream user display name and moves to choose profile avatar`() { runTest { @@ -520,11 +497,11 @@ class OnboardingViewModelTest { fakeVectorFeatures, FakeAnalyticsTracker(), fakeUriFilenameResolver.instance, - fakeRegisterActionHandler.instance, fakeDirectLoginUseCase.instance, fakeStartAuthenticationFlowUseCase.instance, FakeVectorOverrides(), - aBuildMeta() + fakeRegistrationActionHandler.instance, + aBuildMeta(), ).also { viewModel = it initialState = state @@ -556,17 +533,6 @@ class OnboardingViewModelTest { ) } - private fun givenSuccessfulRegistrationForStartAndDummySteps(missingStages: List) { - val flowResult = FlowResult(missingStages = missingStages, completedStages = emptyList()) - givenRegistrationResultsFor( - listOf( - A_LOADABLE_REGISTER_ACTION to RegistrationResult.NextStep(flowResult), - RegisterAction.RegisterDummy to RegistrationResult.Complete(fakeSession) - ) - ) - givenSuccessfullyCreatesAccount(A_HOMESERVER_CAPABILITIES) - } - private fun givenSuccessfullyCreatesAccount(homeServerCapabilities: HomeServerCapabilities) { fakeSession.fakeHomeServerCapabilitiesService.givenCapabilities(homeServerCapabilities) givenInitialisesSession(fakeSession) @@ -578,21 +544,16 @@ class OnboardingViewModelTest { fakeSession.expectStartsSyncing() } - private fun givenRegistrationResultFor(action: RegisterAction, result: RegistrationResult) { + private fun givenRegistrationResultFor(action: RegisterAction, result: RegistrationActionHandler.Result) { givenRegistrationResultsFor(listOf(action to result)) } - private fun givenRegistrationResultsFor(results: List>) { - fakeAuthenticationService.givenRegistrationStarted(true) - val registrationWizard = FakeRegistrationWizard() - fakeAuthenticationService.givenRegistrationWizard(registrationWizard) - fakeRegisterActionHandler.givenResultsFor(registrationWizard, results) + private fun givenRegistrationResultsFor(results: List>) { + fakeRegistrationActionHandler.givenResultsFor(results) } private fun givenRegistrationActionErrors(action: RegisterAction, cause: Throwable) { - val registrationWizard = FakeRegistrationWizard() - fakeAuthenticationService.givenRegistrationWizard(registrationWizard) - fakeRegisterActionHandler.givenThrowsFor(registrationWizard, action, cause) + fakeRegistrationActionHandler.givenThrows(action, cause) } } diff --git a/vector/src/test/java/im/vector/app/features/onboarding/RegistrationActionHandlerTest.kt b/vector/src/test/java/im/vector/app/features/onboarding/RegistrationActionHandlerTest.kt index f6d9317038..ffb1911b20 100644 --- a/vector/src/test/java/im/vector/app/features/onboarding/RegistrationActionHandlerTest.kt +++ b/vector/src/test/java/im/vector/app/features/onboarding/RegistrationActionHandlerTest.kt @@ -16,108 +16,166 @@ package im.vector.app.features.onboarding -import im.vector.app.test.fakes.FakeRegistrationWizard +import im.vector.app.R +import im.vector.app.test.fixtures.SelectedHomeserverStateFixture.aSelectedHomeserverState +import im.vector.app.test.fakes.FakeAuthenticationService +import im.vector.app.test.fakes.FakeRegistrationWizardActionDelegate import im.vector.app.test.fakes.FakeSession -import im.vector.app.test.fixtures.a401ServerError -import io.mockk.coVerifyAll +import im.vector.app.test.fakes.FakeStringProvider +import im.vector.app.test.fakes.FakeVectorFeatures +import im.vector.app.test.fakes.FakeVectorOverrides import kotlinx.coroutines.test.runTest import org.amshove.kluent.shouldBeEqualTo import org.junit.Test +import org.matrix.android.sdk.api.auth.registration.FlowResult import org.matrix.android.sdk.api.auth.registration.RegisterThreePid -import org.matrix.android.sdk.api.auth.registration.RegistrationWizard -import org.matrix.android.sdk.api.auth.registration.RegistrationResult as SdkResult +import org.matrix.android.sdk.api.auth.registration.Stage -private const val IGNORED_DELAY = 0L -private val AN_ERROR = RuntimeException() private val A_SESSION = FakeSession() -private val AN_EXPECTED_RESULT = RegistrationResult.Complete(A_SESSION) -private const val A_USERNAME = "a username" -private const val A_PASSWORD = "a password" -private const val AN_INITIAL_DEVICE_NAME = "a device name" -private const val A_CAPTCHA_RESPONSE = "a captcha response" -private const val A_PID_CODE = "a pid code" -private const val EMAIL_VALIDATED_DELAY = 10000L -private val A_PID_TO_REGISTER = RegisterThreePid.Email("an email") class RegistrationActionHandlerTest { - private val fakeRegistrationWizard = FakeRegistrationWizard() - private val registrationActionHandler = RegistrationActionHandler() + private val fakeWizardActionDelegate = FakeRegistrationWizardActionDelegate() + private val fakeAuthenticationService = FakeAuthenticationService() + private val vectorOverrides = FakeVectorOverrides() + private val vectorFeatures = FakeVectorFeatures() + private val fakeStringProvider = FakeStringProvider().also { + it.given(R.string.matrix_org_server_url, "https://matrix.org") + } + + private val registrationActionHandler = RegistrationActionHandler( + fakeWizardActionDelegate.instance, + fakeAuthenticationService, + vectorOverrides, + vectorFeatures, + fakeStringProvider.instance + ) @Test - fun `when handling register action then delegates to wizard`() = runTest { - val cases = listOf( - case(RegisterAction.StartRegistration) { getRegistrationFlow() }, - case(RegisterAction.CaptchaDone(A_CAPTCHA_RESPONSE)) { performReCaptcha(A_CAPTCHA_RESPONSE) }, - case(RegisterAction.AcceptTerms) { acceptTerms() }, - case(RegisterAction.RegisterDummy) { dummy() }, - case(RegisterAction.AddThreePid(A_PID_TO_REGISTER)) { addThreePid(A_PID_TO_REGISTER) }, - case(RegisterAction.SendAgainThreePid) { sendAgainThreePid() }, - case(RegisterAction.ValidateThreePid(A_PID_CODE)) { handleValidateThreePid(A_PID_CODE) }, - case(RegisterAction.CheckIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY)) { checkIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY) }, - case(RegisterAction.CreateAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME)) { - createAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME) - } - ) + fun `when processing SendAgainThreePid, then ignores result`() = runTest { + val sendAgainThreePid = RegisterAction.SendAgainThreePid + fakeWizardActionDelegate.givenResultsFor(listOf(sendAgainThreePid to RegistrationResult.Complete(A_SESSION))) - cases.forEach { testSuccessfulActionDelegation(it) } + val result = registrationActionHandler.processAction(sendAgainThreePid) + + result shouldBeEqualTo RegistrationActionHandler.Result.Ignored } @Test - fun `given adding an email ThreePid fails with 401, when handling register action, then infer EmailSuccess`() = runTest { - fakeRegistrationWizard.givenAddEmailThreePidErrors( - cause = a401ServerError(), - email = A_PID_TO_REGISTER.email - ) + fun `given wizard delegate returns success, when handling action, then returns success`() = runTest { + fakeWizardActionDelegate.givenResultsFor(listOf(RegisterAction.StartRegistration to RegistrationResult.Complete(A_SESSION))) - val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, RegisterAction.AddThreePid(A_PID_TO_REGISTER)) + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) - result shouldBeEqualTo RegistrationResult.SendEmailSuccess(A_PID_TO_REGISTER.email) + result shouldBeEqualTo RegistrationActionHandler.Result.Success(A_SESSION) } @Test - fun `given email verification errors with 401 then fatal error, when checking email validation, then continues to poll until non 401 error`() = runTest { - val errorsToThrow = listOf( - a401ServerError(), - a401ServerError(), - a401ServerError(), - AN_ERROR - ) - fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow) + fun `given flow result contains unsupported stages, when handling action, then returns UnsupportedStage`() = runTest { + fakeAuthenticationService.givenRegistrationStarted(false) + fakeWizardActionDelegate.givenResultsFor(listOf(RegisterAction.StartRegistration to anUnsupportedResult())) - val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY)) + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) - fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size) - result shouldBeEqualTo RegistrationResult.Error(AN_ERROR) + result shouldBeEqualTo RegistrationActionHandler.Result.UnsupportedStage } @Test - fun `given email verification errors with 401 and succeeds, when checking email validation, then continues to poll until success`() = runTest { - val errorsToThrow = listOf( - a401ServerError(), - a401ServerError(), - a401ServerError() + fun `given flow result with mandatory and optional stages, when handling action, then returns mandatory stage`() = runTest { + val mandatoryStage = Stage.ReCaptcha(mandatory = true, "ignored-key") + val mixedStages = listOf(Stage.Email(mandatory = false), mandatoryStage) + givenFlowResult(mixedStages) + + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) + + result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(mandatoryStage) + } + + @Test + fun `given flow result with only optional stages, when handling action, then returns optional stage`() = runTest { + val optionalStage = Stage.ReCaptcha(mandatory = false, "ignored-key") + givenFlowResult(listOf(optionalStage)) + + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) + + result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(optionalStage) + } + + @Test + fun `given flow result with missing stages, when handling action, then returns MissingNextStage`() = runTest { + givenFlowResult(emptyList()) + + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) + + result shouldBeEqualTo RegistrationActionHandler.Result.MissingNextStage + } + + @Test + fun `given flow result with only optional dummy stage, when handling action, then returns MissingNextStage`() = runTest { + givenFlowResult(listOf(Stage.Dummy(mandatory = false))) + + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) + + result shouldBeEqualTo RegistrationActionHandler.Result.MissingNextStage + } + + @Test + fun `given non matrix org homeserver and flow result with missing mandatory stages, when handling action, then returns first item`() = runTest { + val firstStage = Stage.ReCaptcha(mandatory = true, "ignored-key") + val orderedStages = listOf(firstStage, Stage.Email(mandatory = true), Stage.Msisdn(mandatory = true)) + givenFlowResult(orderedStages) + + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) + + result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(firstStage) + } + + @Test + fun `given matrix org homeserver and flow result with missing mandatory stages, when handling action, then returns email item first`() = runTest { + vectorFeatures.givenCombinedRegisterEnabled() + val expectedFirstItem = Stage.Email(mandatory = true) + val orderedStages = listOf(Stage.ReCaptcha(mandatory = true, "ignored-key"), expectedFirstItem, Stage.Msisdn(mandatory = true)) + givenFlowResult(orderedStages) + + val result = registrationActionHandler.processAction(state = aSelectedHomeserverState("https://matrix.org/"), RegisterAction.StartRegistration) + + result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(expectedFirstItem) + } + + @Test + fun `given password already sent and missing mandatory dummy stage, when handling action, then fast tracks the dummy stage`() = runTest { + val stages = listOf(Stage.ReCaptcha(mandatory = true, "ignored-key"), Stage.Email(mandatory = true), Stage.Dummy(mandatory = true)) + fakeAuthenticationService.givenRegistrationStarted(true) + fakeWizardActionDelegate.givenResultsFor( + listOf( + RegisterAction.StartRegistration to aFlowResult(stages), + RegisterAction.RegisterDummy to RegistrationResult.Complete(A_SESSION) + ) ) - fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow, finally = SdkResult.Success(A_SESSION)) - val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY)) + val result = registrationActionHandler.processAction(RegisterAction.StartRegistration) - fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size + 1) - result shouldBeEqualTo RegistrationResult.Complete(A_SESSION) + result shouldBeEqualTo RegistrationActionHandler.Result.Success(A_SESSION) } - private suspend fun testSuccessfulActionDelegation(case: Case) { - val fakeRegistrationWizard = FakeRegistrationWizard() - val registrationActionHandler = RegistrationActionHandler() - fakeRegistrationWizard.givenSuccessFor(result = A_SESSION, case.expect) - - val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, case.action) - - coVerifyAll { case.expect(fakeRegistrationWizard) } - result shouldBeEqualTo AN_EXPECTED_RESULT + private fun givenFlowResult(stages: List) { + fakeAuthenticationService.givenRegistrationStarted(true) + fakeWizardActionDelegate.givenResultsFor(listOf(RegisterAction.StartRegistration to aFlowResult(stages))) } + + private fun aFlowResult(missingStages: List) = RegistrationResult.NextStep( + FlowResult( + missingStages = missingStages, + completedStages = emptyList() + ) + ) + + private fun anUnsupportedResult() = RegistrationResult.NextStep( + FlowResult( + missingStages = listOf(Stage.Other(mandatory = true, "ignored-type", emptyMap())), + completedStages = emptyList() + ) + ) + + private suspend fun RegistrationActionHandler.processAction(action: RegisterAction) = processAction(aSelectedHomeserverState(), action) } - -private fun case(action: RegisterAction, expect: suspend RegistrationWizard.() -> SdkResult) = Case(action, expect) - -private class Case(val action: RegisterAction, val expect: suspend RegistrationWizard.() -> SdkResult) diff --git a/vector/src/test/java/im/vector/app/features/onboarding/RegistrationWizardActionDelegateTest.kt b/vector/src/test/java/im/vector/app/features/onboarding/RegistrationWizardActionDelegateTest.kt new file mode 100644 index 0000000000..a610486670 --- /dev/null +++ b/vector/src/test/java/im/vector/app/features/onboarding/RegistrationWizardActionDelegateTest.kt @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.vector.app.features.onboarding + +import im.vector.app.test.fakes.FakeAuthenticationService +import im.vector.app.test.fakes.FakeRegistrationWizard +import im.vector.app.test.fakes.FakeSession +import im.vector.app.test.fixtures.a401ServerError +import io.mockk.coVerifyAll +import kotlinx.coroutines.test.runTest +import org.amshove.kluent.shouldBeEqualTo +import org.junit.Test +import org.matrix.android.sdk.api.auth.registration.RegisterThreePid +import org.matrix.android.sdk.api.auth.registration.RegistrationWizard +import org.matrix.android.sdk.api.auth.registration.RegistrationResult as MatrixRegistrationResult + +private const val IGNORED_DELAY = 0L +private val AN_ERROR = RuntimeException() +private val A_SESSION = FakeSession() +private val AN_EXPECTED_RESULT = RegistrationResult.Complete(A_SESSION) +private const val A_USERNAME = "a username" +private const val A_PASSWORD = "a password" +private const val AN_INITIAL_DEVICE_NAME = "a device name" +private const val A_CAPTCHA_RESPONSE = "a captcha response" +private const val A_PID_CODE = "a pid code" +private const val EMAIL_VALIDATED_DELAY = 10000L +private val A_PID_TO_REGISTER = RegisterThreePid.Email("an email") + +class RegistrationWizardActionDelegateTest { + + private val fakeRegistrationWizard = FakeRegistrationWizard() + private val fakeAuthenticationService = FakeAuthenticationService().also { + it.givenRegistrationWizard(fakeRegistrationWizard) + } + private val registrationActionHandler = RegistrationWizardActionDelegate(fakeAuthenticationService) + + @Test + fun `when handling register action then delegates to wizard`() = runTest { + val cases = listOf( + case(RegisterAction.StartRegistration) { getRegistrationFlow() }, + case(RegisterAction.CaptchaDone(A_CAPTCHA_RESPONSE)) { performReCaptcha(A_CAPTCHA_RESPONSE) }, + case(RegisterAction.AcceptTerms) { acceptTerms() }, + case(RegisterAction.RegisterDummy) { dummy() }, + case(RegisterAction.AddThreePid(A_PID_TO_REGISTER)) { addThreePid(A_PID_TO_REGISTER) }, + case(RegisterAction.SendAgainThreePid) { sendAgainThreePid() }, + case(RegisterAction.ValidateThreePid(A_PID_CODE)) { handleValidateThreePid(A_PID_CODE) }, + case(RegisterAction.CheckIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY)) { checkIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY) }, + case(RegisterAction.CreateAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME)) { + createAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME) + } + ) + + cases.forEach { testSuccessfulActionDelegation(it) } + } + + @Test + fun `given adding an email ThreePid fails with 401, when handling register action, then infer EmailSuccess`() = runTest { + fakeRegistrationWizard.givenAddEmailThreePidErrors( + cause = a401ServerError(), + email = A_PID_TO_REGISTER.email + ) + + val result = registrationActionHandler.executeAction(RegisterAction.AddThreePid(A_PID_TO_REGISTER)) + + result shouldBeEqualTo RegistrationResult.SendEmailSuccess(A_PID_TO_REGISTER.email) + } + + @Test + fun `given email verification errors with 401 then fatal error, when checking email validation, then continues to poll until non 401 error`() = runTest { + val errorsToThrow = listOf( + a401ServerError(), + a401ServerError(), + a401ServerError(), + AN_ERROR + ) + fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow) + + val result = registrationActionHandler.executeAction(RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY)) + + fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size) + result shouldBeEqualTo RegistrationResult.Error(AN_ERROR) + } + + @Test + fun `given email verification errors with 401 and succeeds, when checking email validation, then continues to poll until success`() = runTest { + val errorsToThrow = listOf( + a401ServerError(), + a401ServerError(), + a401ServerError() + ) + fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow, finally = MatrixRegistrationResult.Success(A_SESSION)) + + val result = registrationActionHandler.executeAction(RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY)) + + fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size + 1) + result shouldBeEqualTo RegistrationResult.Complete(A_SESSION) + } + + private suspend fun testSuccessfulActionDelegation(case: Case) { + val fakeRegistrationWizard = FakeRegistrationWizard() + val authenticationService = FakeAuthenticationService().also { it.givenRegistrationWizard(fakeRegistrationWizard) } + val registrationActionHandler = RegistrationWizardActionDelegate(authenticationService) + fakeRegistrationWizard.givenSuccessFor(result = A_SESSION, case.expect) + + val result = registrationActionHandler.executeAction(case.action) + + coVerifyAll { case.expect(fakeRegistrationWizard) } + result shouldBeEqualTo AN_EXPECTED_RESULT + } +} + +private fun case(action: RegisterAction, expect: suspend RegistrationWizard.() -> MatrixRegistrationResult) = Case(action, expect) + +private class Case(val action: RegisterAction, val expect: suspend RegistrationWizard.() -> MatrixRegistrationResult) diff --git a/vector/src/test/java/im/vector/app/test/fakes/FakeRegisterActionHandler.kt b/vector/src/test/java/im/vector/app/test/fakes/FakeRegistrationActionHandler.kt similarity index 65% rename from vector/src/test/java/im/vector/app/test/fakes/FakeRegisterActionHandler.kt rename to vector/src/test/java/im/vector/app/test/fakes/FakeRegistrationActionHandler.kt index f5824e5866..ddaea38302 100644 --- a/vector/src/test/java/im/vector/app/test/fakes/FakeRegisterActionHandler.kt +++ b/vector/src/test/java/im/vector/app/test/fakes/FakeRegistrationActionHandler.kt @@ -18,23 +18,21 @@ package im.vector.app.test.fakes import im.vector.app.features.onboarding.RegisterAction import im.vector.app.features.onboarding.RegistrationActionHandler -import im.vector.app.features.onboarding.RegistrationResult import io.mockk.coEvery import io.mockk.mockk -import org.matrix.android.sdk.api.auth.registration.RegistrationWizard -class FakeRegisterActionHandler { +class FakeRegistrationActionHandler { val instance = mockk() - fun givenResultsFor(wizard: RegistrationWizard, result: List>) { - coEvery { instance.handleRegisterAction(wizard, any()) } answers { call -> + fun givenThrows(action: RegisterAction, cause: Throwable) { + coEvery { instance.processAction(any(), action) } throws cause + } + + fun givenResultsFor(result: List>) { + coEvery { instance.processAction(any(), any()) } answers { call -> val actionArg = call.invocation.args[1] as RegisterAction result.first { it.first == actionArg }.second } } - - fun givenThrowsFor(wizard: RegistrationWizard, action: RegisterAction, cause: Throwable) { - coEvery { instance.handleRegisterAction(wizard, action) } throws cause - } } diff --git a/vector/src/test/java/im/vector/app/test/fakes/FakeRegistrationWizardActionDelegate.kt b/vector/src/test/java/im/vector/app/test/fakes/FakeRegistrationWizardActionDelegate.kt new file mode 100644 index 0000000000..3e95be3dae --- /dev/null +++ b/vector/src/test/java/im/vector/app/test/fakes/FakeRegistrationWizardActionDelegate.kt @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.vector.app.test.fakes + +import im.vector.app.features.onboarding.RegisterAction +import im.vector.app.features.onboarding.RegistrationResult +import im.vector.app.features.onboarding.RegistrationWizardActionDelegate +import io.mockk.coEvery +import io.mockk.mockk + +class FakeRegistrationWizardActionDelegate { + + val instance = mockk() + + fun givenResultsFor(result: List>) { + coEvery { instance.executeAction(any()) } answers { call -> + val actionArg = call.invocation.args[0] as RegisterAction + result.first { it.first == actionArg }.second + } + } + + fun givenThrowsFor(action: RegisterAction, cause: Throwable) { + coEvery { instance.executeAction(action) } throws cause + } +} diff --git a/vector/src/test/java/im/vector/app/test/fakes/FakeVectorFeatures.kt b/vector/src/test/java/im/vector/app/test/fakes/FakeVectorFeatures.kt index aeabcce7cd..e227a1a686 100644 --- a/vector/src/test/java/im/vector/app/test/fakes/FakeVectorFeatures.kt +++ b/vector/src/test/java/im/vector/app/test/fakes/FakeVectorFeatures.kt @@ -26,4 +26,8 @@ class FakeVectorFeatures : VectorFeatures by spyk() { fun givenPersonalisationEnabled() { every { isOnboardingPersonalizeEnabled() } returns true } + + fun givenCombinedRegisterEnabled() { + every { isOnboardingCombinedRegisterEnabled() } returns true + } } diff --git a/vector/src/test/java/im/vector/app/test/fixtures/SelectedHomeserverStateFixture.kt b/vector/src/test/java/im/vector/app/test/fixtures/SelectedHomeserverStateFixture.kt new file mode 100644 index 0000000000..9d7ff7d291 --- /dev/null +++ b/vector/src/test/java/im/vector/app/test/fixtures/SelectedHomeserverStateFixture.kt @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.vector.app.test.fixtures + +import im.vector.app.features.onboarding.SelectedHomeserverState + +object SelectedHomeserverStateFixture { + + fun aSelectedHomeserverState( + userFacingUrl: String = "https://myhomeserver.com", + ) = SelectedHomeserverState(userFacingUrl = userFacingUrl) +}