extracting registration action business logic to the handler abstraction and adding tests

- renames the existing handler to a wizard delegate
This commit is contained in:
Adam Brown 2022-05-17 17:47:42 +01:00
parent 928183ff64
commit ba18c6f3e2
12 changed files with 599 additions and 328 deletions

View File

@ -20,7 +20,7 @@ package im.vector.app.features.onboarding
import im.vector.app.core.platform.VectorViewEvents import im.vector.app.core.platform.VectorViewEvents
import im.vector.app.features.login.ServerType import im.vector.app.features.login.ServerType
import im.vector.app.features.login.SignMode import im.vector.app.features.login.SignMode
import org.matrix.android.sdk.api.auth.registration.FlowResult import org.matrix.android.sdk.api.auth.registration.Stage
/** /**
* Transient events for Login. * Transient events for Login.
@ -30,7 +30,9 @@ sealed class OnboardingViewEvents : VectorViewEvents {
data class Failure(val throwable: Throwable) : OnboardingViewEvents() data class Failure(val throwable: Throwable) : OnboardingViewEvents()
data class DeeplinkAuthenticationFailure(val retryAction: OnboardingAction) : OnboardingViewEvents() data class DeeplinkAuthenticationFailure(val retryAction: OnboardingAction) : OnboardingViewEvents()
data class RegistrationFlowResult(val flowResult: FlowResult, val isRegistrationStarted: Boolean) : OnboardingViewEvents() object DisplayRegistrationFallback : OnboardingViewEvents()
data class DisplayRegistrationStage(val stage: Stage) : OnboardingViewEvents()
object DisplayStartRegistration : OnboardingViewEvents()
object OutdatedHomeserver : OnboardingViewEvents() object OutdatedHomeserver : OnboardingViewEvents()
// Navigation event // Navigation event

View File

@ -47,7 +47,6 @@ import im.vector.app.features.login.ServerType
import im.vector.app.features.login.SignMode import im.vector.app.features.login.SignMode
import im.vector.app.features.onboarding.OnboardingAction.AuthenticateAction import im.vector.app.features.onboarding.OnboardingAction.AuthenticateAction
import im.vector.app.features.onboarding.StartAuthenticationFlowUseCase.StartAuthenticationResult import im.vector.app.features.onboarding.StartAuthenticationFlowUseCase.StartAuthenticationResult
import im.vector.app.features.onboarding.ftueauth.MatrixOrgRegistrationStagesComparator
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -56,9 +55,7 @@ import org.matrix.android.sdk.api.auth.HomeServerHistoryService
import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig
import org.matrix.android.sdk.api.auth.data.SsoIdentityProvider import org.matrix.android.sdk.api.auth.data.SsoIdentityProvider
import org.matrix.android.sdk.api.auth.login.LoginWizard import org.matrix.android.sdk.api.auth.login.LoginWizard
import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
import org.matrix.android.sdk.api.auth.registration.Stage
import org.matrix.android.sdk.api.failure.isHomeserverUnavailable import org.matrix.android.sdk.api.failure.isHomeserverUnavailable
import org.matrix.android.sdk.api.session.Session import org.matrix.android.sdk.api.session.Session
import timber.log.Timber import timber.log.Timber
@ -80,11 +77,11 @@ class OnboardingViewModel @AssistedInject constructor(
private val vectorFeatures: VectorFeatures, private val vectorFeatures: VectorFeatures,
private val analyticsTracker: AnalyticsTracker, private val analyticsTracker: AnalyticsTracker,
private val uriFilenameResolver: UriFilenameResolver, private val uriFilenameResolver: UriFilenameResolver,
private val registrationActionHandler: RegistrationActionHandler,
private val directLoginUseCase: DirectLoginUseCase, private val directLoginUseCase: DirectLoginUseCase,
private val startAuthenticationFlowUseCase: StartAuthenticationFlowUseCase, private val startAuthenticationFlowUseCase: StartAuthenticationFlowUseCase,
private val vectorOverrides: VectorOverrides, private val vectorOverrides: VectorOverrides,
private val buildMeta: BuildMeta private val registrationActionHandler: RegistrationActionHandler,
private val buildMeta: BuildMeta,
) : VectorViewModel<OnboardingViewState, OnboardingAction, OnboardingViewEvents>(initialState) { ) : VectorViewModel<OnboardingViewState, OnboardingAction, OnboardingViewEvents>(initialState) {
@AssistedFactory @AssistedFactory
@ -150,18 +147,18 @@ class OnboardingViewModel @AssistedInject constructor(
is OnboardingAction.WebLoginSuccess -> handleWebLoginSuccess(action) is OnboardingAction.WebLoginSuccess -> handleWebLoginSuccess(action)
is OnboardingAction.ResetPassword -> handleResetPassword(action) is OnboardingAction.ResetPassword -> handleResetPassword(action)
is OnboardingAction.ResetPasswordMailConfirmed -> handleResetPasswordMailConfirmed() is OnboardingAction.ResetPasswordMailConfirmed -> handleResetPasswordMailConfirmed()
is OnboardingAction.PostRegisterAction -> handleRegisterAction(action.registerAction, ::emitFlowResultViewEvent) is OnboardingAction.PostRegisterAction -> handleRegisterAction(action.registerAction)
is OnboardingAction.ResetAction -> handleResetAction(action) is OnboardingAction.ResetAction -> handleResetAction(action)
is OnboardingAction.UserAcceptCertificate -> handleUserAcceptCertificate(action) is OnboardingAction.UserAcceptCertificate -> handleUserAcceptCertificate(action)
OnboardingAction.ClearHomeServerHistory -> handleClearHomeServerHistory() OnboardingAction.ClearHomeServerHistory -> handleClearHomeServerHistory()
is OnboardingAction.UpdateDisplayName -> updateDisplayName(action.displayName) is OnboardingAction.UpdateDisplayName -> updateDisplayName(action.displayName)
OnboardingAction.UpdateDisplayNameSkipped -> handleDisplayNameStepComplete() OnboardingAction.UpdateDisplayNameSkipped -> handleDisplayNameStepComplete()
OnboardingAction.UpdateProfilePictureSkipped -> completePersonalization() OnboardingAction.UpdateProfilePictureSkipped -> completePersonalization()
OnboardingAction.PersonalizeProfile -> handlePersonalizeProfile() OnboardingAction.PersonalizeProfile -> handlePersonalizeProfile()
is OnboardingAction.ProfilePictureSelected -> handleProfilePictureSelected(action) is OnboardingAction.ProfilePictureSelected -> handleProfilePictureSelected(action)
OnboardingAction.SaveSelectedProfilePicture -> updateProfilePicture() OnboardingAction.SaveSelectedProfilePicture -> updateProfilePicture()
is OnboardingAction.PostViewEvent -> _viewEvents.post(action.viewEvent) is OnboardingAction.PostViewEvent -> _viewEvents.post(action.viewEvent)
OnboardingAction.StopEmailValidationCheck -> cancelWaitForEmailValidation() OnboardingAction.StopEmailValidationCheck -> cancelWaitForEmailValidation()
} }
} }
@ -259,12 +256,12 @@ class OnboardingViewModel @AssistedInject constructor(
} }
} }
private fun handleRegisterAction(action: RegisterAction, onNextRegistrationStepAction: (FlowResult) -> Unit) { private fun handleRegisterAction(action: RegisterAction) {
val job = viewModelScope.launch { val job = viewModelScope.launch {
if (action.hasLoadingState()) { if (action.hasLoadingState()) {
setState { copy(isLoading = true) } setState { copy(isLoading = true) }
} }
internalRegisterAction(action, onNextRegistrationStepAction) internalRegisterAction(action)
setState { copy(isLoading = false) } setState { copy(isLoading = false) }
} }
@ -275,23 +272,28 @@ class OnboardingViewModel @AssistedInject constructor(
} }
} }
private suspend fun internalRegisterAction(action: RegisterAction, onNextRegistrationStepAction: (FlowResult) -> Unit) { private suspend fun internalRegisterAction(action: RegisterAction, overrideNextStage: (() -> Unit)? = null) {
runCatching { registrationActionHandler.handleRegisterAction(registrationWizard, action) } runCatching { registrationActionHandler.processAction(awaitState().selectedHomeserver, action) }
.fold( .fold(
onSuccess = { onSuccess = {
when { when (it) {
action.ignoresResult() -> { RegistrationActionHandler.Result.Ignored -> {
// do nothing // do nothing
} }
else -> when (it) { is RegistrationActionHandler.Result.NextStage -> {
is RegistrationResult.Complete -> onSessionCreated( overrideNextStage?.invoke() ?: _viewEvents.post(OnboardingViewEvents.DisplayRegistrationStage(it.stage))
it.session, }
authenticationDescription = awaitState().selectedAuthenticationState.description is RegistrationActionHandler.Result.Success -> onSessionCreated(
?: AuthenticationDescription.Register(AuthenticationDescription.AuthenticationType.Other) it.session,
) authenticationDescription = awaitState().selectedAuthenticationState.description
is RegistrationResult.NextStep -> onFlowResponse(it.flowResult, onNextRegistrationStepAction) ?: AuthenticationDescription.Register(AuthenticationDescription.AuthenticationType.Other)
is RegistrationResult.SendEmailSuccess -> _viewEvents.post(OnboardingViewEvents.OnSendEmailSuccess(it.email)) )
is RegistrationResult.Error -> _viewEvents.post(OnboardingViewEvents.Failure(it.cause)) RegistrationActionHandler.Result.StartRegistration -> _viewEvents.post(OnboardingViewEvents.DisplayStartRegistration)
RegistrationActionHandler.Result.UnsupportedStage -> _viewEvents.post(OnboardingViewEvents.DisplayRegistrationFallback)
is RegistrationActionHandler.Result.SendEmailSuccess -> _viewEvents.post(OnboardingViewEvents.OnSendEmailSuccess(it.email))
is RegistrationActionHandler.Result.Error -> _viewEvents.post(OnboardingViewEvents.Failure(it.cause))
RegistrationActionHandler.Result.MissingNextStage -> {
_viewEvents.post(OnboardingViewEvents.Failure(IllegalStateException("No next registration stage found")))
} }
} }
}, },
@ -303,18 +305,6 @@ class OnboardingViewModel @AssistedInject constructor(
) )
} }
private fun emitFlowResultViewEvent(flowResult: FlowResult) {
withState { state ->
val orderedResult = when {
state.hasSelectedMatrixOrg() && vectorFeatures.isOnboardingCombinedRegisterEnabled() -> flowResult.copy(
missingStages = flowResult.missingStages.sortedWith(MatrixOrgRegistrationStagesComparator())
)
else -> flowResult
}
_viewEvents.post(OnboardingViewEvents.RegistrationFlowResult(orderedResult, isRegistrationStarted))
}
}
private fun OnboardingViewState.hasSelectedMatrixOrg() = selectedHomeserver.userFacingUrl == matrixOrgUrl private fun OnboardingViewState.hasSelectedMatrixOrg() = selectedHomeserver.userFacingUrl == matrixOrgUrl
private fun handleRegisterWith(action: AuthenticateAction.Register) { private fun handleRegisterWith(action: AuthenticateAction.Register) {
@ -328,8 +318,7 @@ class OnboardingViewModel @AssistedInject constructor(
action.username, action.username,
action.password, action.password,
action.initialDeviceName action.initialDeviceName
), )
::emitFlowResultViewEvent
) )
} }
@ -382,8 +371,8 @@ class OnboardingViewModel @AssistedInject constructor(
private fun handleUpdateSignMode(action: OnboardingAction.UpdateSignMode) { private fun handleUpdateSignMode(action: OnboardingAction.UpdateSignMode) {
updateSignMode(action.signMode) updateSignMode(action.signMode)
when (action.signMode) { when (action.signMode) {
SignMode.SignUp -> handleRegisterAction(RegisterAction.StartRegistration, ::emitFlowResultViewEvent) SignMode.SignUp -> handleRegisterAction(RegisterAction.StartRegistration)
SignMode.SignIn -> startAuthenticationFlow() SignMode.SignIn -> startAuthenticationFlow()
SignMode.SignInWithMatrixId -> _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignInWithMatrixId)) SignMode.SignInWithMatrixId -> _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignInWithMatrixId))
SignMode.Unknown -> Unit SignMode.Unknown -> Unit
} }
@ -530,19 +519,6 @@ class OnboardingViewModel @AssistedInject constructor(
_viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignIn)) _viewEvents.post(OnboardingViewEvents.OnSignModeSelected(SignMode.SignIn))
} }
private suspend fun onFlowResponse(flowResult: FlowResult, onNextRegistrationStepAction: (FlowResult) -> Unit) {
// If dummy stage is mandatory, and password is already sent, do the dummy stage now
if (isRegistrationStarted && flowResult.missingStages.any { it is Stage.Dummy && it.mandatory }) {
handleRegisterDummy(onNextRegistrationStepAction)
} else {
onNextRegistrationStepAction(flowResult)
}
}
private suspend fun handleRegisterDummy(onNextRegistrationStepAction: (FlowResult) -> Unit) {
internalRegisterAction(RegisterAction.RegisterDummy, onNextRegistrationStepAction)
}
private suspend fun onSessionCreated(session: Session, authenticationDescription: AuthenticationDescription) { private suspend fun onSessionCreated(session: Session, authenticationDescription: AuthenticationDescription) {
val state = awaitState() val state = awaitState()
state.useCase?.let { useCase -> state.useCase?.let { useCase ->
@ -684,7 +660,7 @@ class OnboardingViewModel @AssistedInject constructor(
} }
OnboardingFlow.SignUp -> { OnboardingFlow.SignUp -> {
updateSignMode(SignMode.SignUp) updateSignMode(SignMode.SignUp)
internalRegisterAction(RegisterAction.StartRegistration, ::emitFlowResultViewEvent) internalRegisterAction(RegisterAction.StartRegistration)
} }
OnboardingFlow.SignInSignUp, OnboardingFlow.SignInSignUp,
null -> { null -> {

View File

@ -16,105 +16,91 @@
package im.vector.app.features.onboarding package im.vector.app.features.onboarding
import im.vector.app.R
import im.vector.app.core.resources.StringProvider
import im.vector.app.core.utils.ensureTrailingSlash
import im.vector.app.features.VectorFeatures
import im.vector.app.features.VectorOverrides
import im.vector.app.features.login.isSupported
import im.vector.app.features.onboarding.ftueauth.MatrixOrgRegistrationStagesComparator
import kotlinx.coroutines.flow.first
import org.matrix.android.sdk.api.auth.AuthenticationService
import org.matrix.android.sdk.api.auth.registration.FlowResult import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.RegisterThreePid import org.matrix.android.sdk.api.auth.registration.Stage
import org.matrix.android.sdk.api.auth.registration.RegistrationResult.FlowResponse
import org.matrix.android.sdk.api.auth.registration.RegistrationResult.Success
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
import org.matrix.android.sdk.api.failure.is401
import org.matrix.android.sdk.api.session.Session import org.matrix.android.sdk.api.session.Session
import javax.inject.Inject import javax.inject.Inject
import org.matrix.android.sdk.api.auth.registration.RegistrationResult as MatrixRegistrationResult
class RegistrationActionHandler @Inject constructor() { class RegistrationActionHandler @Inject constructor(
private val registrationWizardActionDelegate: RegistrationWizardActionDelegate,
private val authenticationService: AuthenticationService,
private val vectorOverrides: VectorOverrides,
private val vectorFeatures: VectorFeatures,
stringProvider: StringProvider
) {
suspend fun handleRegisterAction(registrationWizard: RegistrationWizard, action: RegisterAction): RegistrationResult { private val matrixOrgUrl = stringProvider.getString(R.string.matrix_org_server_url).ensureTrailingSlash()
return when (action) {
RegisterAction.StartRegistration -> resultOf { registrationWizard.getRegistrationFlow() } suspend fun processAction(state: SelectedHomeserverState, action: RegisterAction): Result {
is RegisterAction.CaptchaDone -> resultOf { registrationWizard.performReCaptcha(action.captchaResponse) } val result = registrationWizardActionDelegate.executeAction(action)
is RegisterAction.AcceptTerms -> resultOf { registrationWizard.acceptTerms() } return when {
is RegisterAction.RegisterDummy -> resultOf { registrationWizard.dummy() } action.ignoresResult() -> Result.Ignored
is RegisterAction.AddThreePid -> handleAddThreePid(registrationWizard, action) else -> when (result) {
is RegisterAction.SendAgainThreePid -> resultOf { registrationWizard.sendAgainThreePid() } is RegistrationResult.Complete -> Result.Success(result.session)
is RegisterAction.ValidateThreePid -> resultOf { registrationWizard.handleValidateThreePid(action.code) } is RegistrationResult.NextStep -> processFlowResult(result, state)
is RegisterAction.CheckIfEmailHasBeenValidated -> handleCheckIfEmailIsValidated(registrationWizard, action.delayMillis) is RegistrationResult.SendEmailSuccess -> Result.SendEmailSuccess(result.email)
is RegisterAction.CreateAccount -> resultOf { is RegistrationResult.Error -> Result.Error(result.cause)
registrationWizard.createAccount(
action.username,
action.password,
action.initialDeviceName
)
} }
} }
} }
private suspend fun handleAddThreePid(wizard: RegistrationWizard, action: RegisterAction.AddThreePid): RegistrationResult { private suspend fun processFlowResult(result: RegistrationResult.NextStep, state: SelectedHomeserverState): Result {
return runCatching { wizard.addThreePid(action.threePid) }.fold( // If dummy stage is mandatory, and password is already sent, do the dummy stage now
onSuccess = { it.toRegistrationResult() }, return if (authenticationService.isRegistrationStarted() && result.flowResult.missingStages.hasMandatoryDummy()) {
onFailure = { processAction(state, RegisterAction.RegisterDummy)
when { } else {
action.threePid is RegisterThreePid.Email && it.is401() -> RegistrationResult.SendEmailSuccess(action.threePid.email) handleNextStep(state, result.flowResult)
else -> RegistrationResult.Error(it) }
}
}
)
} }
private tailrec suspend fun handleCheckIfEmailIsValidated(registrationWizard: RegistrationWizard, delayMillis: Long): RegistrationResult { private suspend fun handleNextStep(state: SelectedHomeserverState, flowResult: FlowResult): Result {
return runCatching { registrationWizard.checkIfEmailHasBeenValidated(delayMillis) }.fold( return when {
onSuccess = { it.toRegistrationResult() }, flowResult.registrationShouldFallback() -> Result.UnsupportedStage
onFailure = { authenticationService.isRegistrationStarted() -> findNextStage(state, flowResult)
when { else -> Result.StartRegistration
it.is401() -> null // recursively continue to check with a delay }
else -> RegistrationResult.Error(it) }
}
} private fun findNextStage(state: SelectedHomeserverState, flowResult: FlowResult): Result {
) ?: handleCheckIfEmailIsValidated(registrationWizard, 10_000) val orderedResult = when {
state.hasSelectedMatrixOrg() && vectorFeatures.isOnboardingCombinedRegisterEnabled() -> flowResult.copy(
missingStages = flowResult.missingStages.sortedWith(MatrixOrgRegistrationStagesComparator())
)
else -> flowResult
}
return orderedResult.findNextRegistrationStage()
?.let { Result.NextStage(it) }
?: Result.MissingNextStage
}
private fun FlowResult.findNextRegistrationStage() = missingStages.firstMandatoryOrNull() ?: missingStages.ignoreDummy().firstOptionalOrNull()
private suspend fun FlowResult.registrationShouldFallback() = vectorOverrides.forceLoginFallback.first() || missingStages.any { !it.isSupported() }
private fun SelectedHomeserverState.hasSelectedMatrixOrg() = userFacingUrl == matrixOrgUrl
sealed interface Result {
data class Success(val session: Session) : Result
data class NextStage(val stage: Stage) : Result
data class Error(val cause: Throwable) : Result
data class SendEmailSuccess(val email: String) : Result
object MissingNextStage : Result
object StartRegistration : Result
object UnsupportedStage : Result
object Ignored : Result
} }
} }
private inline fun resultOf(block: () -> MatrixRegistrationResult): RegistrationResult { private fun List<Stage>.firstMandatoryOrNull() = firstOrNull { it.mandatory }
return runCatching { block() }.fold( private fun List<Stage>.firstOptionalOrNull() = firstOrNull { !it.mandatory }
onSuccess = { it.toRegistrationResult() }, private fun List<Stage>.ignoreDummy() = filter { it !is Stage.Dummy }
onFailure = { RegistrationResult.Error(it) } private fun List<Stage>.hasMandatoryDummy() = any { it is Stage.Dummy && it.mandatory }
)
}
private fun MatrixRegistrationResult.toRegistrationResult() = when (this) {
is FlowResponse -> RegistrationResult.NextStep(flowResult)
is Success -> RegistrationResult.Complete(session)
}
sealed interface RegistrationResult {
data class Error(val cause: Throwable) : RegistrationResult
data class Complete(val session: Session) : RegistrationResult
data class NextStep(val flowResult: FlowResult) : RegistrationResult
data class SendEmailSuccess(val email: String) : RegistrationResult
}
sealed interface RegisterAction {
object StartRegistration : RegisterAction
data class CreateAccount(val username: String, val password: String, val initialDeviceName: String) : RegisterAction
data class AddThreePid(val threePid: RegisterThreePid) : RegisterAction
object SendAgainThreePid : RegisterAction
// TODO Confirm Email (from link in the email, open in the phone, intercepted by the app)
data class ValidateThreePid(val code: String) : RegisterAction
data class CheckIfEmailHasBeenValidated(val delayMillis: Long) : RegisterAction
data class CaptchaDone(val captchaResponse: String) : RegisterAction
object AcceptTerms : RegisterAction
object RegisterDummy : RegisterAction
}
fun RegisterAction.ignoresResult() = when (this) {
is RegisterAction.SendAgainThreePid -> true
else -> false
}
fun RegisterAction.hasLoadingState() = when (this) {
is RegisterAction.CheckIfEmailHasBeenValidated -> false
else -> true
}

View File

@ -0,0 +1,123 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.app.features.onboarding
import org.matrix.android.sdk.api.auth.AuthenticationService
import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.RegisterThreePid
import org.matrix.android.sdk.api.auth.registration.RegistrationResult.FlowResponse
import org.matrix.android.sdk.api.auth.registration.RegistrationResult.Success
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
import org.matrix.android.sdk.api.failure.is401
import org.matrix.android.sdk.api.session.Session
import javax.inject.Inject
import org.matrix.android.sdk.api.auth.registration.RegistrationResult as MatrixRegistrationResult
class RegistrationWizardActionDelegate @Inject constructor(
private val authenticationService: AuthenticationService
) {
private val registrationWizard: RegistrationWizard
get() = authenticationService.getRegistrationWizard()
suspend fun executeAction(action: RegisterAction): RegistrationResult {
return when (action) {
RegisterAction.StartRegistration -> resultOf { registrationWizard.getRegistrationFlow() }
is RegisterAction.CaptchaDone -> resultOf { registrationWizard.performReCaptcha(action.captchaResponse) }
is RegisterAction.AcceptTerms -> resultOf { registrationWizard.acceptTerms() }
is RegisterAction.RegisterDummy -> resultOf { registrationWizard.dummy() }
is RegisterAction.AddThreePid -> handleAddThreePid(registrationWizard, action)
is RegisterAction.SendAgainThreePid -> resultOf { registrationWizard.sendAgainThreePid() }
is RegisterAction.ValidateThreePid -> resultOf { registrationWizard.handleValidateThreePid(action.code) }
is RegisterAction.CheckIfEmailHasBeenValidated -> handleCheckIfEmailIsValidated(registrationWizard, action.delayMillis)
is RegisterAction.CreateAccount -> resultOf {
registrationWizard.createAccount(
action.username,
action.password,
action.initialDeviceName
)
}
}
}
private suspend fun handleAddThreePid(wizard: RegistrationWizard, action: RegisterAction.AddThreePid): RegistrationResult {
return runCatching { wizard.addThreePid(action.threePid) }.fold(
onSuccess = { it.toRegistrationResult() },
onFailure = {
when {
action.threePid is RegisterThreePid.Email && it.is401() -> RegistrationResult.SendEmailSuccess(action.threePid.email)
else -> RegistrationResult.Error(it)
}
}
)
}
private tailrec suspend fun handleCheckIfEmailIsValidated(registrationWizard: RegistrationWizard, delayMillis: Long): RegistrationResult {
return runCatching { registrationWizard.checkIfEmailHasBeenValidated(delayMillis) }.fold(
onSuccess = { it.toRegistrationResult() },
onFailure = {
when {
it.is401() -> null // recursively continue to check with a delay
else -> RegistrationResult.Error(it)
}
}
) ?: handleCheckIfEmailIsValidated(registrationWizard, 10_000)
}
}
private inline fun resultOf(block: () -> MatrixRegistrationResult): RegistrationResult {
return runCatching { block() }.fold(
onSuccess = { it.toRegistrationResult() },
onFailure = { RegistrationResult.Error(it) }
)
}
private fun MatrixRegistrationResult.toRegistrationResult() = when (this) {
is FlowResponse -> RegistrationResult.NextStep(flowResult)
is Success -> RegistrationResult.Complete(session)
}
sealed interface RegistrationResult {
data class Error(val cause: Throwable) : RegistrationResult
data class Complete(val session: Session) : RegistrationResult
data class NextStep(val flowResult: FlowResult) : RegistrationResult
data class SendEmailSuccess(val email: String) : RegistrationResult
}
sealed interface RegisterAction {
object StartRegistration : RegisterAction
data class CreateAccount(val username: String, val password: String, val initialDeviceName: String) : RegisterAction
data class AddThreePid(val threePid: RegisterThreePid) : RegisterAction
object SendAgainThreePid : RegisterAction
data class ValidateThreePid(val code: String) : RegisterAction
data class CheckIfEmailHasBeenValidated(val delayMillis: Long) : RegisterAction
data class CaptchaDone(val captchaResponse: String) : RegisterAction
object AcceptTerms : RegisterAction
object RegisterDummy : RegisterAction
}
fun RegisterAction.ignoresResult() = when (this) {
is RegisterAction.SendAgainThreePid -> true
else -> false
}
fun RegisterAction.hasLoadingState() = when (this) {
is RegisterAction.CheckIfEmailHasBeenValidated -> false
else -> true
}

View File

@ -44,7 +44,6 @@ import im.vector.app.features.login.LoginMode
import im.vector.app.features.login.ServerType import im.vector.app.features.login.ServerType
import im.vector.app.features.login.SignMode import im.vector.app.features.login.SignMode
import im.vector.app.features.login.TextInputFormFragmentMode import im.vector.app.features.login.TextInputFormFragmentMode
import im.vector.app.features.login.isSupported
import im.vector.app.features.onboarding.OnboardingAction import im.vector.app.features.onboarding.OnboardingAction
import im.vector.app.features.onboarding.OnboardingActivity import im.vector.app.features.onboarding.OnboardingActivity
import im.vector.app.features.onboarding.OnboardingVariant import im.vector.app.features.onboarding.OnboardingVariant
@ -129,10 +128,7 @@ class FtueAuthVariant(
private fun handleOnboardingViewEvents(viewEvents: OnboardingViewEvents) { private fun handleOnboardingViewEvents(viewEvents: OnboardingViewEvents) {
when (viewEvents) { when (viewEvents) {
is OnboardingViewEvents.RegistrationFlowResult -> { is OnboardingViewEvents.OutdatedHomeserver -> {
onRegistrationFlow(viewEvents)
}
is OnboardingViewEvents.OutdatedHomeserver -> {
MaterialAlertDialogBuilder(activity) MaterialAlertDialogBuilder(activity)
.setTitle(R.string.login_error_outdated_homeserver_title) .setTitle(R.string.login_error_outdated_homeserver_title)
.setMessage(R.string.login_error_outdated_homeserver_warning_content) .setMessage(R.string.login_error_outdated_homeserver_warning_content)
@ -227,9 +223,15 @@ class FtueAuthVariant(
option = commonOption option = commonOption
) )
} }
OnboardingViewEvents.OnHomeserverEdited -> activity.popBackstack() OnboardingViewEvents.OnHomeserverEdited -> activity.popBackstack()
OnboardingViewEvents.OpenCombinedLogin -> onStartCombinedLogin() OnboardingViewEvents.OpenCombinedLogin -> onStartCombinedLogin()
is OnboardingViewEvents.DeeplinkAuthenticationFailure -> onDeeplinkedHomeserverUnavailable(viewEvents) is OnboardingViewEvents.DeeplinkAuthenticationFailure -> onDeeplinkedHomeserverUnavailable(viewEvents)
OnboardingViewEvents.DisplayRegistrationFallback -> displayFallbackWebDialog()
is OnboardingViewEvents.DisplayRegistrationStage -> doStage(viewEvents.stage)
OnboardingViewEvents.DisplayStartRegistration -> when {
vectorFeatures.isOnboardingCombinedRegisterEnabled() -> openStartCombinedRegister()
else -> openAuthLoginFragmentWithTag(FRAGMENT_REGISTRATION_STAGE_TAG)
}
} }
} }
@ -253,25 +255,10 @@ class FtueAuthVariant(
addRegistrationStageFragmentToBackstack(FtueAuthCombinedLoginFragment::class.java) addRegistrationStageFragmentToBackstack(FtueAuthCombinedLoginFragment::class.java)
} }
private fun onRegistrationFlow(viewEvents: OnboardingViewEvents.RegistrationFlowResult) {
when {
registrationShouldFallback(viewEvents) -> displayFallbackWebDialog()
viewEvents.isRegistrationStarted -> handleRegistrationNavigation(viewEvents.flowResult.missingStages)
vectorFeatures.isOnboardingCombinedRegisterEnabled() -> openStartCombinedRegister()
else -> openAuthLoginFragmentWithTag(FRAGMENT_REGISTRATION_STAGE_TAG)
}
}
private fun openStartCombinedRegister() { private fun openStartCombinedRegister() {
addRegistrationStageFragmentToBackstack(FtueAuthCombinedRegisterFragment::class.java) addRegistrationStageFragmentToBackstack(FtueAuthCombinedRegisterFragment::class.java)
} }
private fun registrationShouldFallback(registrationFlowResult: OnboardingViewEvents.RegistrationFlowResult) =
isForceLoginFallbackEnabled || registrationFlowResult.containsUnsupportedRegistrationFlow()
private fun OnboardingViewEvents.RegistrationFlowResult.containsUnsupportedRegistrationFlow() =
flowResult.missingStages.any { !it.isSupported() }
private fun displayFallbackWebDialog() { private fun displayFallbackWebDialog() {
MaterialAlertDialogBuilder(activity) MaterialAlertDialogBuilder(activity)
.setTitle(R.string.app_name) .setTitle(R.string.app_name)
@ -381,23 +368,6 @@ class FtueAuthVariant(
?.let { onboardingViewModel.handle(OnboardingAction.LoginWithToken(it)) } ?.let { onboardingViewModel.handle(OnboardingAction.LoginWithToken(it)) }
} }
private fun handleRegistrationNavigation(remainingStages: List<Stage>) {
// Complete all mandatory stages first
val mandatoryStage = remainingStages.firstOrNull { it.mandatory }
if (mandatoryStage != null) {
doStage(mandatoryStage)
} else {
// Consider optional stages
val optionalStage = remainingStages.firstOrNull { !it.mandatory && it !is Stage.Dummy }
if (optionalStage == null) {
// Should not happen...
} else {
doStage(optionalStage)
}
}
}
private fun doStage(stage: Stage) { private fun doStage(stage: Stage) {
// Ensure there is no fragment for registration stage in the backstack // Ensure there is no fragment for registration stage in the backstack
supportFragmentManager.popBackStack(FRAGMENT_REGISTRATION_STAGE_TAG, FragmentManager.POP_BACK_STACK_INCLUSIVE) supportFragmentManager.popBackStack(FRAGMENT_REGISTRATION_STAGE_TAG, FragmentManager.POP_BACK_STACK_INCLUSIVE)

View File

@ -31,9 +31,8 @@ import im.vector.app.test.fakes.FakeContext
import im.vector.app.test.fakes.FakeDirectLoginUseCase import im.vector.app.test.fakes.FakeDirectLoginUseCase
import im.vector.app.test.fakes.FakeHomeServerConnectionConfigFactory import im.vector.app.test.fakes.FakeHomeServerConnectionConfigFactory
import im.vector.app.test.fakes.FakeHomeServerHistoryService import im.vector.app.test.fakes.FakeHomeServerHistoryService
import im.vector.app.test.fakes.FakeRegistrationActionHandler
import im.vector.app.test.fakes.FakeLoginWizard import im.vector.app.test.fakes.FakeLoginWizard
import im.vector.app.test.fakes.FakeRegisterActionHandler
import im.vector.app.test.fakes.FakeRegistrationWizard
import im.vector.app.test.fakes.FakeSession import im.vector.app.test.fakes.FakeSession
import im.vector.app.test.fakes.FakeStartAuthenticationFlowUseCase import im.vector.app.test.fakes.FakeStartAuthenticationFlowUseCase
import im.vector.app.test.fakes.FakeStringProvider import im.vector.app.test.fakes.FakeStringProvider
@ -50,7 +49,6 @@ import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig
import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.Stage import org.matrix.android.sdk.api.auth.registration.Stage
import org.matrix.android.sdk.api.session.Session import org.matrix.android.sdk.api.session.Session
import org.matrix.android.sdk.api.session.homeserver.HomeServerCapabilities import org.matrix.android.sdk.api.session.homeserver.HomeServerCapabilities
@ -62,8 +60,7 @@ private val A_LOADABLE_REGISTER_ACTION = RegisterAction.StartRegistration
private val A_NON_LOADABLE_REGISTER_ACTION = RegisterAction.CheckIfEmailHasBeenValidated(delayMillis = -1L) private val A_NON_LOADABLE_REGISTER_ACTION = RegisterAction.CheckIfEmailHasBeenValidated(delayMillis = -1L)
private val A_RESULT_IGNORED_REGISTER_ACTION = RegisterAction.SendAgainThreePid private val A_RESULT_IGNORED_REGISTER_ACTION = RegisterAction.SendAgainThreePid
private val A_HOMESERVER_CAPABILITIES = aHomeServerCapabilities(canChangeDisplayName = true, canChangeAvatar = true) private val A_HOMESERVER_CAPABILITIES = aHomeServerCapabilities(canChangeDisplayName = true, canChangeAvatar = true)
private val AN_IGNORED_FLOW_RESULT = FlowResult(missingStages = emptyList(), completedStages = emptyList()) private val ANY_CONTINUING_REGISTRATION_RESULT = RegistrationActionHandler.Result.NextStage(Stage.Dummy(mandatory = true))
private val ANY_CONTINUING_REGISTRATION_RESULT = RegistrationResult.NextStep(AN_IGNORED_FLOW_RESULT)
private val A_DIRECT_LOGIN = OnboardingAction.AuthenticateAction.LoginDirect("@a-user:id.org", "a-password", "a-device-name") private val A_DIRECT_LOGIN = OnboardingAction.AuthenticateAction.LoginDirect("@a-user:id.org", "a-password", "a-device-name")
private const val A_HOMESERVER_URL = "https://edited-homeserver.org" private const val A_HOMESERVER_URL = "https://edited-homeserver.org"
private val A_HOMESERVER_CONFIG = HomeServerConnectionConfig(FakeUri().instance) private val A_HOMESERVER_CONFIG = HomeServerConnectionConfig(FakeUri().instance)
@ -82,7 +79,7 @@ class OnboardingViewModelTest {
private val fakeUriFilenameResolver = FakeUriFilenameResolver() private val fakeUriFilenameResolver = FakeUriFilenameResolver()
private val fakeActiveSessionHolder = FakeActiveSessionHolder(fakeSession) private val fakeActiveSessionHolder = FakeActiveSessionHolder(fakeSession)
private val fakeAuthenticationService = FakeAuthenticationService() private val fakeAuthenticationService = FakeAuthenticationService()
private val fakeRegisterActionHandler = FakeRegisterActionHandler() private val fakeRegistrationActionHandler = FakeRegistrationActionHandler()
private val fakeDirectLoginUseCase = FakeDirectLoginUseCase() private val fakeDirectLoginUseCase = FakeDirectLoginUseCase()
private val fakeVectorFeatures = FakeVectorFeatures() private val fakeVectorFeatures = FakeVectorFeatures()
private val fakeHomeServerConnectionConfigFactory = FakeHomeServerConnectionConfigFactory() private val fakeHomeServerConnectionConfigFactory = FakeHomeServerConnectionConfigFactory()
@ -199,7 +196,7 @@ class OnboardingViewModelTest {
{ copy(isLoading = true) }, { copy(isLoading = true) },
{ copy(isLoading = false) } { copy(isLoading = false) }
) )
.assertEvents(OnboardingViewEvents.RegistrationFlowResult(ANY_CONTINUING_REGISTRATION_RESULT.flowResult, isRegistrationStarted = true)) .assertEvents(OnboardingViewEvents.DisplayRegistrationStage(ANY_CONTINUING_REGISTRATION_RESULT.stage))
.finish() .finish()
} }
@ -216,7 +213,7 @@ class OnboardingViewModelTest {
{ copy(isLoading = true) }, { copy(isLoading = true) },
{ copy(isLoading = false) } { copy(isLoading = false) }
) )
.assertEvents(OnboardingViewEvents.RegistrationFlowResult(ANY_CONTINUING_REGISTRATION_RESULT.flowResult, isRegistrationStarted = true)) .assertEvents(OnboardingViewEvents.DisplayRegistrationStage(ANY_CONTINUING_REGISTRATION_RESULT.stage))
.finish() .finish()
} }
@ -229,14 +226,14 @@ class OnboardingViewModelTest {
test test
.assertState(initialState) .assertState(initialState)
.assertEvents(OnboardingViewEvents.RegistrationFlowResult(ANY_CONTINUING_REGISTRATION_RESULT.flowResult, isRegistrationStarted = true)) .assertEvents(OnboardingViewEvents.DisplayRegistrationStage(ANY_CONTINUING_REGISTRATION_RESULT.stage))
.finish() .finish()
} }
@Test @Test
fun `given register action ignores result, when handling action, then does nothing on success`() = runTest { fun `given register action ignores result, when handling action, then does nothing on success`() = runTest {
val test = viewModel.test() val test = viewModel.test()
givenRegistrationResultFor(A_RESULT_IGNORED_REGISTER_ACTION, RegistrationResult.NextStep(AN_IGNORED_FLOW_RESULT)) givenRegistrationResultFor(A_RESULT_IGNORED_REGISTER_ACTION, RegistrationActionHandler.Result.Ignored)
viewModel.handle(OnboardingAction.PostRegisterAction(A_RESULT_IGNORED_REGISTER_ACTION)) viewModel.handle(OnboardingAction.PostRegisterAction(A_RESULT_IGNORED_REGISTER_ACTION))
@ -276,7 +273,7 @@ class OnboardingViewModelTest {
viewModelWith(initialState.copy(onboardingFlow = OnboardingFlow.SignUp)) viewModelWith(initialState.copy(onboardingFlow = OnboardingFlow.SignUp))
fakeHomeServerConnectionConfigFactory.givenConfigFor(A_HOMESERVER_URL, A_HOMESERVER_CONFIG) fakeHomeServerConnectionConfigFactory.givenConfigFor(A_HOMESERVER_URL, A_HOMESERVER_CONFIG)
fakeStartAuthenticationFlowUseCase.givenResult(A_HOMESERVER_CONFIG, StartAuthenticationResult(isHomeserverOutdated = false, SELECTED_HOMESERVER_STATE)) fakeStartAuthenticationFlowUseCase.givenResult(A_HOMESERVER_CONFIG, StartAuthenticationResult(isHomeserverOutdated = false, SELECTED_HOMESERVER_STATE))
givenRegistrationResultFor(RegisterAction.StartRegistration, RegistrationResult.NextStep(AN_IGNORED_FLOW_RESULT)) givenRegistrationResultFor(RegisterAction.StartRegistration, ANY_CONTINUING_REGISTRATION_RESULT)
fakeHomeServerHistoryService.expectUrlToBeAdded(A_HOMESERVER_CONFIG.homeServerUri.toString()) fakeHomeServerHistoryService.expectUrlToBeAdded(A_HOMESERVER_CONFIG.homeServerUri.toString())
val test = viewModel.test() val test = viewModel.test()
@ -318,11 +315,11 @@ class OnboardingViewModelTest {
@Test @Test
fun `given personalisation enabled, when registering account, then updates state and emits account created event`() = runTest { fun `given personalisation enabled, when registering account, then updates state and emits account created event`() = runTest {
fakeVectorFeatures.givenPersonalisationEnabled() fakeVectorFeatures.givenPersonalisationEnabled()
givenRegistrationResultFor(A_LOADABLE_REGISTER_ACTION, RegistrationResult.Complete(fakeSession))
givenSuccessfullyCreatesAccount(A_HOMESERVER_CAPABILITIES) givenSuccessfullyCreatesAccount(A_HOMESERVER_CAPABILITIES)
givenRegistrationResultFor(RegisterAction.StartRegistration, RegistrationActionHandler.Result.Success(fakeSession))
val test = viewModel.test() val test = viewModel.test()
viewModel.handle(OnboardingAction.PostRegisterAction(A_LOADABLE_REGISTER_ACTION)) viewModel.handle(OnboardingAction.PostRegisterAction(RegisterAction.StartRegistration))
test test
.assertStatesChanges( .assertStatesChanges(
@ -334,26 +331,6 @@ class OnboardingViewModelTest {
.finish() .finish()
} }
@Test
fun `given personalisation enabled and registration has started and has dummy step to do, when handling action, then ignores other steps and does dummy`() {
runTest {
fakeVectorFeatures.givenPersonalisationEnabled()
givenSuccessfulRegistrationForStartAndDummySteps(missingStages = listOf(Stage.Dummy(mandatory = true)))
val test = viewModel.test()
viewModel.handle(OnboardingAction.PostRegisterAction(A_LOADABLE_REGISTER_ACTION))
test
.assertStatesChanges(
initialState,
{ copy(isLoading = true) },
{ copy(isLoading = false, personalizationState = A_HOMESERVER_CAPABILITIES.toPersonalisationState()) }
)
.assertEvents(OnboardingViewEvents.OnAccountCreated)
.finish()
}
}
@Test @Test
fun `given changing profile avatar is supported, when updating display name, then updates upstream user display name and moves to choose profile avatar`() { fun `given changing profile avatar is supported, when updating display name, then updates upstream user display name and moves to choose profile avatar`() {
runTest { runTest {
@ -520,11 +497,11 @@ class OnboardingViewModelTest {
fakeVectorFeatures, fakeVectorFeatures,
FakeAnalyticsTracker(), FakeAnalyticsTracker(),
fakeUriFilenameResolver.instance, fakeUriFilenameResolver.instance,
fakeRegisterActionHandler.instance,
fakeDirectLoginUseCase.instance, fakeDirectLoginUseCase.instance,
fakeStartAuthenticationFlowUseCase.instance, fakeStartAuthenticationFlowUseCase.instance,
FakeVectorOverrides(), FakeVectorOverrides(),
aBuildMeta() fakeRegistrationActionHandler.instance,
aBuildMeta(),
).also { ).also {
viewModel = it viewModel = it
initialState = state initialState = state
@ -556,17 +533,6 @@ class OnboardingViewModelTest {
) )
} }
private fun givenSuccessfulRegistrationForStartAndDummySteps(missingStages: List<Stage>) {
val flowResult = FlowResult(missingStages = missingStages, completedStages = emptyList())
givenRegistrationResultsFor(
listOf(
A_LOADABLE_REGISTER_ACTION to RegistrationResult.NextStep(flowResult),
RegisterAction.RegisterDummy to RegistrationResult.Complete(fakeSession)
)
)
givenSuccessfullyCreatesAccount(A_HOMESERVER_CAPABILITIES)
}
private fun givenSuccessfullyCreatesAccount(homeServerCapabilities: HomeServerCapabilities) { private fun givenSuccessfullyCreatesAccount(homeServerCapabilities: HomeServerCapabilities) {
fakeSession.fakeHomeServerCapabilitiesService.givenCapabilities(homeServerCapabilities) fakeSession.fakeHomeServerCapabilitiesService.givenCapabilities(homeServerCapabilities)
givenInitialisesSession(fakeSession) givenInitialisesSession(fakeSession)
@ -578,21 +544,16 @@ class OnboardingViewModelTest {
fakeSession.expectStartsSyncing() fakeSession.expectStartsSyncing()
} }
private fun givenRegistrationResultFor(action: RegisterAction, result: RegistrationResult) { private fun givenRegistrationResultFor(action: RegisterAction, result: RegistrationActionHandler.Result) {
givenRegistrationResultsFor(listOf(action to result)) givenRegistrationResultsFor(listOf(action to result))
} }
private fun givenRegistrationResultsFor(results: List<Pair<RegisterAction, RegistrationResult>>) { private fun givenRegistrationResultsFor(results: List<Pair<RegisterAction, RegistrationActionHandler.Result>>) {
fakeAuthenticationService.givenRegistrationStarted(true) fakeRegistrationActionHandler.givenResultsFor(results)
val registrationWizard = FakeRegistrationWizard()
fakeAuthenticationService.givenRegistrationWizard(registrationWizard)
fakeRegisterActionHandler.givenResultsFor(registrationWizard, results)
} }
private fun givenRegistrationActionErrors(action: RegisterAction, cause: Throwable) { private fun givenRegistrationActionErrors(action: RegisterAction, cause: Throwable) {
val registrationWizard = FakeRegistrationWizard() fakeRegistrationActionHandler.givenThrows(action, cause)
fakeAuthenticationService.givenRegistrationWizard(registrationWizard)
fakeRegisterActionHandler.givenThrowsFor(registrationWizard, action, cause)
} }
} }

View File

@ -16,108 +16,166 @@
package im.vector.app.features.onboarding package im.vector.app.features.onboarding
import im.vector.app.test.fakes.FakeRegistrationWizard import im.vector.app.R
import im.vector.app.test.fixtures.SelectedHomeserverStateFixture.aSelectedHomeserverState
import im.vector.app.test.fakes.FakeAuthenticationService
import im.vector.app.test.fakes.FakeRegistrationWizardActionDelegate
import im.vector.app.test.fakes.FakeSession import im.vector.app.test.fakes.FakeSession
import im.vector.app.test.fixtures.a401ServerError import im.vector.app.test.fakes.FakeStringProvider
import io.mockk.coVerifyAll import im.vector.app.test.fakes.FakeVectorFeatures
import im.vector.app.test.fakes.FakeVectorOverrides
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo import org.amshove.kluent.shouldBeEqualTo
import org.junit.Test import org.junit.Test
import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.RegisterThreePid import org.matrix.android.sdk.api.auth.registration.RegisterThreePid
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard import org.matrix.android.sdk.api.auth.registration.Stage
import org.matrix.android.sdk.api.auth.registration.RegistrationResult as SdkResult
private const val IGNORED_DELAY = 0L
private val AN_ERROR = RuntimeException()
private val A_SESSION = FakeSession() private val A_SESSION = FakeSession()
private val AN_EXPECTED_RESULT = RegistrationResult.Complete(A_SESSION)
private const val A_USERNAME = "a username"
private const val A_PASSWORD = "a password"
private const val AN_INITIAL_DEVICE_NAME = "a device name"
private const val A_CAPTCHA_RESPONSE = "a captcha response"
private const val A_PID_CODE = "a pid code"
private const val EMAIL_VALIDATED_DELAY = 10000L
private val A_PID_TO_REGISTER = RegisterThreePid.Email("an email")
class RegistrationActionHandlerTest { class RegistrationActionHandlerTest {
private val fakeRegistrationWizard = FakeRegistrationWizard() private val fakeWizardActionDelegate = FakeRegistrationWizardActionDelegate()
private val registrationActionHandler = RegistrationActionHandler() private val fakeAuthenticationService = FakeAuthenticationService()
private val vectorOverrides = FakeVectorOverrides()
private val vectorFeatures = FakeVectorFeatures()
private val fakeStringProvider = FakeStringProvider().also {
it.given(R.string.matrix_org_server_url, "https://matrix.org")
}
private val registrationActionHandler = RegistrationActionHandler(
fakeWizardActionDelegate.instance,
fakeAuthenticationService,
vectorOverrides,
vectorFeatures,
fakeStringProvider.instance
)
@Test @Test
fun `when handling register action then delegates to wizard`() = runTest { fun `when processing SendAgainThreePid, then ignores result`() = runTest {
val cases = listOf( val sendAgainThreePid = RegisterAction.SendAgainThreePid
case(RegisterAction.StartRegistration) { getRegistrationFlow() }, fakeWizardActionDelegate.givenResultsFor(listOf(sendAgainThreePid to RegistrationResult.Complete(A_SESSION)))
case(RegisterAction.CaptchaDone(A_CAPTCHA_RESPONSE)) { performReCaptcha(A_CAPTCHA_RESPONSE) },
case(RegisterAction.AcceptTerms) { acceptTerms() },
case(RegisterAction.RegisterDummy) { dummy() },
case(RegisterAction.AddThreePid(A_PID_TO_REGISTER)) { addThreePid(A_PID_TO_REGISTER) },
case(RegisterAction.SendAgainThreePid) { sendAgainThreePid() },
case(RegisterAction.ValidateThreePid(A_PID_CODE)) { handleValidateThreePid(A_PID_CODE) },
case(RegisterAction.CheckIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY)) { checkIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY) },
case(RegisterAction.CreateAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME)) {
createAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME)
}
)
cases.forEach { testSuccessfulActionDelegation(it) } val result = registrationActionHandler.processAction(sendAgainThreePid)
result shouldBeEqualTo RegistrationActionHandler.Result.Ignored
} }
@Test @Test
fun `given adding an email ThreePid fails with 401, when handling register action, then infer EmailSuccess`() = runTest { fun `given wizard delegate returns success, when handling action, then returns success`() = runTest {
fakeRegistrationWizard.givenAddEmailThreePidErrors( fakeWizardActionDelegate.givenResultsFor(listOf(RegisterAction.StartRegistration to RegistrationResult.Complete(A_SESSION)))
cause = a401ServerError(),
email = A_PID_TO_REGISTER.email
)
val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, RegisterAction.AddThreePid(A_PID_TO_REGISTER)) val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationResult.SendEmailSuccess(A_PID_TO_REGISTER.email) result shouldBeEqualTo RegistrationActionHandler.Result.Success(A_SESSION)
} }
@Test @Test
fun `given email verification errors with 401 then fatal error, when checking email validation, then continues to poll until non 401 error`() = runTest { fun `given flow result contains unsupported stages, when handling action, then returns UnsupportedStage`() = runTest {
val errorsToThrow = listOf( fakeAuthenticationService.givenRegistrationStarted(false)
a401ServerError(), fakeWizardActionDelegate.givenResultsFor(listOf(RegisterAction.StartRegistration to anUnsupportedResult()))
a401ServerError(),
a401ServerError(),
AN_ERROR
)
fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow)
val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY)) val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size) result shouldBeEqualTo RegistrationActionHandler.Result.UnsupportedStage
result shouldBeEqualTo RegistrationResult.Error(AN_ERROR)
} }
@Test @Test
fun `given email verification errors with 401 and succeeds, when checking email validation, then continues to poll until success`() = runTest { fun `given flow result with mandatory and optional stages, when handling action, then returns mandatory stage`() = runTest {
val errorsToThrow = listOf( val mandatoryStage = Stage.ReCaptcha(mandatory = true, "ignored-key")
a401ServerError(), val mixedStages = listOf(Stage.Email(mandatory = false), mandatoryStage)
a401ServerError(), givenFlowResult(mixedStages)
a401ServerError()
val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(mandatoryStage)
}
@Test
fun `given flow result with only optional stages, when handling action, then returns optional stage`() = runTest {
val optionalStage = Stage.ReCaptcha(mandatory = false, "ignored-key")
givenFlowResult(listOf(optionalStage))
val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(optionalStage)
}
@Test
fun `given flow result with missing stages, when handling action, then returns MissingNextStage`() = runTest {
givenFlowResult(emptyList())
val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationActionHandler.Result.MissingNextStage
}
@Test
fun `given flow result with only optional dummy stage, when handling action, then returns MissingNextStage`() = runTest {
givenFlowResult(listOf(Stage.Dummy(mandatory = false)))
val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationActionHandler.Result.MissingNextStage
}
@Test
fun `given non matrix org homeserver and flow result with missing mandatory stages, when handling action, then returns first item`() = runTest {
val firstStage = Stage.ReCaptcha(mandatory = true, "ignored-key")
val orderedStages = listOf(firstStage, Stage.Email(mandatory = true), Stage.Msisdn(mandatory = true))
givenFlowResult(orderedStages)
val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(firstStage)
}
@Test
fun `given matrix org homeserver and flow result with missing mandatory stages, when handling action, then returns email item first`() = runTest {
vectorFeatures.givenCombinedRegisterEnabled()
val expectedFirstItem = Stage.Email(mandatory = true)
val orderedStages = listOf(Stage.ReCaptcha(mandatory = true, "ignored-key"), expectedFirstItem, Stage.Msisdn(mandatory = true))
givenFlowResult(orderedStages)
val result = registrationActionHandler.processAction(state = aSelectedHomeserverState("https://matrix.org/"), RegisterAction.StartRegistration)
result shouldBeEqualTo RegistrationActionHandler.Result.NextStage(expectedFirstItem)
}
@Test
fun `given password already sent and missing mandatory dummy stage, when handling action, then fast tracks the dummy stage`() = runTest {
val stages = listOf(Stage.ReCaptcha(mandatory = true, "ignored-key"), Stage.Email(mandatory = true), Stage.Dummy(mandatory = true))
fakeAuthenticationService.givenRegistrationStarted(true)
fakeWizardActionDelegate.givenResultsFor(
listOf(
RegisterAction.StartRegistration to aFlowResult(stages),
RegisterAction.RegisterDummy to RegistrationResult.Complete(A_SESSION)
)
) )
fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow, finally = SdkResult.Success(A_SESSION))
val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY)) val result = registrationActionHandler.processAction(RegisterAction.StartRegistration)
fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size + 1) result shouldBeEqualTo RegistrationActionHandler.Result.Success(A_SESSION)
result shouldBeEqualTo RegistrationResult.Complete(A_SESSION)
} }
private suspend fun testSuccessfulActionDelegation(case: Case) { private fun givenFlowResult(stages: List<Stage>) {
val fakeRegistrationWizard = FakeRegistrationWizard() fakeAuthenticationService.givenRegistrationStarted(true)
val registrationActionHandler = RegistrationActionHandler() fakeWizardActionDelegate.givenResultsFor(listOf(RegisterAction.StartRegistration to aFlowResult(stages)))
fakeRegistrationWizard.givenSuccessFor(result = A_SESSION, case.expect)
val result = registrationActionHandler.handleRegisterAction(fakeRegistrationWizard, case.action)
coVerifyAll { case.expect(fakeRegistrationWizard) }
result shouldBeEqualTo AN_EXPECTED_RESULT
} }
private fun aFlowResult(missingStages: List<Stage>) = RegistrationResult.NextStep(
FlowResult(
missingStages = missingStages,
completedStages = emptyList()
)
)
private fun anUnsupportedResult() = RegistrationResult.NextStep(
FlowResult(
missingStages = listOf(Stage.Other(mandatory = true, "ignored-type", emptyMap<String, String>())),
completedStages = emptyList()
)
)
private suspend fun RegistrationActionHandler.processAction(action: RegisterAction) = processAction(aSelectedHomeserverState(), action)
} }
private fun case(action: RegisterAction, expect: suspend RegistrationWizard.() -> SdkResult) = Case(action, expect)
private class Case(val action: RegisterAction, val expect: suspend RegistrationWizard.() -> SdkResult)

View File

@ -0,0 +1,128 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.app.features.onboarding
import im.vector.app.test.fakes.FakeAuthenticationService
import im.vector.app.test.fakes.FakeRegistrationWizard
import im.vector.app.test.fakes.FakeSession
import im.vector.app.test.fixtures.a401ServerError
import io.mockk.coVerifyAll
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo
import org.junit.Test
import org.matrix.android.sdk.api.auth.registration.RegisterThreePid
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
import org.matrix.android.sdk.api.auth.registration.RegistrationResult as MatrixRegistrationResult
private const val IGNORED_DELAY = 0L
private val AN_ERROR = RuntimeException()
private val A_SESSION = FakeSession()
private val AN_EXPECTED_RESULT = RegistrationResult.Complete(A_SESSION)
private const val A_USERNAME = "a username"
private const val A_PASSWORD = "a password"
private const val AN_INITIAL_DEVICE_NAME = "a device name"
private const val A_CAPTCHA_RESPONSE = "a captcha response"
private const val A_PID_CODE = "a pid code"
private const val EMAIL_VALIDATED_DELAY = 10000L
private val A_PID_TO_REGISTER = RegisterThreePid.Email("an email")
class RegistrationWizardActionDelegateTest {
private val fakeRegistrationWizard = FakeRegistrationWizard()
private val fakeAuthenticationService = FakeAuthenticationService().also {
it.givenRegistrationWizard(fakeRegistrationWizard)
}
private val registrationActionHandler = RegistrationWizardActionDelegate(fakeAuthenticationService)
@Test
fun `when handling register action then delegates to wizard`() = runTest {
val cases = listOf(
case(RegisterAction.StartRegistration) { getRegistrationFlow() },
case(RegisterAction.CaptchaDone(A_CAPTCHA_RESPONSE)) { performReCaptcha(A_CAPTCHA_RESPONSE) },
case(RegisterAction.AcceptTerms) { acceptTerms() },
case(RegisterAction.RegisterDummy) { dummy() },
case(RegisterAction.AddThreePid(A_PID_TO_REGISTER)) { addThreePid(A_PID_TO_REGISTER) },
case(RegisterAction.SendAgainThreePid) { sendAgainThreePid() },
case(RegisterAction.ValidateThreePid(A_PID_CODE)) { handleValidateThreePid(A_PID_CODE) },
case(RegisterAction.CheckIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY)) { checkIfEmailHasBeenValidated(EMAIL_VALIDATED_DELAY) },
case(RegisterAction.CreateAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME)) {
createAccount(A_USERNAME, A_PASSWORD, AN_INITIAL_DEVICE_NAME)
}
)
cases.forEach { testSuccessfulActionDelegation(it) }
}
@Test
fun `given adding an email ThreePid fails with 401, when handling register action, then infer EmailSuccess`() = runTest {
fakeRegistrationWizard.givenAddEmailThreePidErrors(
cause = a401ServerError(),
email = A_PID_TO_REGISTER.email
)
val result = registrationActionHandler.executeAction(RegisterAction.AddThreePid(A_PID_TO_REGISTER))
result shouldBeEqualTo RegistrationResult.SendEmailSuccess(A_PID_TO_REGISTER.email)
}
@Test
fun `given email verification errors with 401 then fatal error, when checking email validation, then continues to poll until non 401 error`() = runTest {
val errorsToThrow = listOf(
a401ServerError(),
a401ServerError(),
a401ServerError(),
AN_ERROR
)
fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow)
val result = registrationActionHandler.executeAction(RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY))
fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size)
result shouldBeEqualTo RegistrationResult.Error(AN_ERROR)
}
@Test
fun `given email verification errors with 401 and succeeds, when checking email validation, then continues to poll until success`() = runTest {
val errorsToThrow = listOf(
a401ServerError(),
a401ServerError(),
a401ServerError()
)
fakeRegistrationWizard.givenCheckIfEmailHasBeenValidatedErrors(errorsToThrow, finally = MatrixRegistrationResult.Success(A_SESSION))
val result = registrationActionHandler.executeAction(RegisterAction.CheckIfEmailHasBeenValidated(IGNORED_DELAY))
fakeRegistrationWizard.verifyCheckedEmailedVerification(times = errorsToThrow.size + 1)
result shouldBeEqualTo RegistrationResult.Complete(A_SESSION)
}
private suspend fun testSuccessfulActionDelegation(case: Case) {
val fakeRegistrationWizard = FakeRegistrationWizard()
val authenticationService = FakeAuthenticationService().also { it.givenRegistrationWizard(fakeRegistrationWizard) }
val registrationActionHandler = RegistrationWizardActionDelegate(authenticationService)
fakeRegistrationWizard.givenSuccessFor(result = A_SESSION, case.expect)
val result = registrationActionHandler.executeAction(case.action)
coVerifyAll { case.expect(fakeRegistrationWizard) }
result shouldBeEqualTo AN_EXPECTED_RESULT
}
}
private fun case(action: RegisterAction, expect: suspend RegistrationWizard.() -> MatrixRegistrationResult) = Case(action, expect)
private class Case(val action: RegisterAction, val expect: suspend RegistrationWizard.() -> MatrixRegistrationResult)

View File

@ -18,23 +18,21 @@ package im.vector.app.test.fakes
import im.vector.app.features.onboarding.RegisterAction import im.vector.app.features.onboarding.RegisterAction
import im.vector.app.features.onboarding.RegistrationActionHandler import im.vector.app.features.onboarding.RegistrationActionHandler
import im.vector.app.features.onboarding.RegistrationResult
import io.mockk.coEvery import io.mockk.coEvery
import io.mockk.mockk import io.mockk.mockk
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
class FakeRegisterActionHandler { class FakeRegistrationActionHandler {
val instance = mockk<RegistrationActionHandler>() val instance = mockk<RegistrationActionHandler>()
fun givenResultsFor(wizard: RegistrationWizard, result: List<Pair<RegisterAction, RegistrationResult>>) { fun givenThrows(action: RegisterAction, cause: Throwable) {
coEvery { instance.handleRegisterAction(wizard, any()) } answers { call -> coEvery { instance.processAction(any(), action) } throws cause
}
fun givenResultsFor(result: List<Pair<RegisterAction, RegistrationActionHandler.Result>>) {
coEvery { instance.processAction(any(), any()) } answers { call ->
val actionArg = call.invocation.args[1] as RegisterAction val actionArg = call.invocation.args[1] as RegisterAction
result.first { it.first == actionArg }.second result.first { it.first == actionArg }.second
} }
} }
fun givenThrowsFor(wizard: RegistrationWizard, action: RegisterAction, cause: Throwable) {
coEvery { instance.handleRegisterAction(wizard, action) } throws cause
}
} }

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.app.test.fakes
import im.vector.app.features.onboarding.RegisterAction
import im.vector.app.features.onboarding.RegistrationResult
import im.vector.app.features.onboarding.RegistrationWizardActionDelegate
import io.mockk.coEvery
import io.mockk.mockk
class FakeRegistrationWizardActionDelegate {
val instance = mockk<RegistrationWizardActionDelegate>()
fun givenResultsFor(result: List<Pair<RegisterAction, RegistrationResult>>) {
coEvery { instance.executeAction(any()) } answers { call ->
val actionArg = call.invocation.args[0] as RegisterAction
result.first { it.first == actionArg }.second
}
}
fun givenThrowsFor(action: RegisterAction, cause: Throwable) {
coEvery { instance.executeAction(action) } throws cause
}
}

View File

@ -26,4 +26,8 @@ class FakeVectorFeatures : VectorFeatures by spyk<DefaultVectorFeatures>() {
fun givenPersonalisationEnabled() { fun givenPersonalisationEnabled() {
every { isOnboardingPersonalizeEnabled() } returns true every { isOnboardingPersonalizeEnabled() } returns true
} }
fun givenCombinedRegisterEnabled() {
every { isOnboardingCombinedRegisterEnabled() } returns true
}
} }

View File

@ -0,0 +1,26 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.app.test.fixtures
import im.vector.app.features.onboarding.SelectedHomeserverState
object SelectedHomeserverStateFixture {
fun aSelectedHomeserverState(
userFacingUrl: String = "https://myhomeserver.com",
) = SelectedHomeserverState(userFacingUrl = userFacingUrl)
}