diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt index 8e0620507f..6740bac642 100755 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt @@ -24,6 +24,7 @@ import com.squareup.moshi.Types import dagger.Lazy import java.io.File import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.ConcurrentHashMap import javax.inject.Inject import kotlin.jvm.Throws import kotlin.math.max @@ -32,6 +33,8 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.cancelChildren import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext import org.matrix.android.sdk.api.MatrixCallback import org.matrix.android.sdk.api.NoOpMatrixCallback @@ -152,6 +155,11 @@ internal class DefaultCryptoService @Inject constructor( private var olmMachine: OlmMachine? = null private val deviceObserver: DeviceUpdateObserver = DeviceUpdateObserver() + // Locks for some of our operations + private val keyClaimLock: Mutex = Mutex() + private val outgointRequestsLock: Mutex = Mutex() + private val roomKeyShareLocks: ConcurrentHashMap = ConcurrentHashMap() + suspend fun onStateEvent(roomId: String, event: Event) { when (event.getClearType()) { EventType.STATE_ROOM_ENCRYPTION -> onRoomEncryptionEvent(roomId, event) @@ -650,23 +658,28 @@ internal class DefaultCryptoService @Inject constructor( } private suspend fun preshareGroupSession(roomId: String, roomMembers: List) { - // TODO this needs to be locked per room - val request = olmMachine!!.getMissingSessions(roomMembers) - - if (request != null) { - // This request can only be a keys claim request. - when (request) { - is Request.KeysClaim -> { - claimKeys(request) + keyClaimLock.withLock { + val request = olmMachine!!.getMissingSessions(roomMembers) + if (request != null) { + // This request can only be a keys claim request. + when (request) { + is Request.KeysClaim -> { + claimKeys(request) + } } } } - for (toDeviceRequest in olmMachine!!.shareGroupSession(roomId, roomMembers)) { - // This request can only be a to-device request. - when (toDeviceRequest) { - is Request.ToDevice -> { - sendToDevice(toDeviceRequest) + val keyShareLock = roomKeyShareLocks.getOrDefault(roomId, Mutex()) + + keyShareLock.withLock { + for (toDeviceRequest in olmMachine!!.shareGroupSession(roomId, roomMembers)) { + // TODO these requests should be sent out in parallel + // This request can only be a to-device request. + when (toDeviceRequest) { + is Request.ToDevice -> { + sendToDevice(toDeviceRequest) + } } } } @@ -699,7 +712,6 @@ internal class DefaultCryptoService @Inject constructor( } private suspend fun queryKeys(outgoingRequest: Request.KeysQuery) { - Timber.v("HELLO KEYS QUERY REQUEST ${outgoingRequest.users}") val params = DownloadKeysForUsersTask.Params(outgoingRequest.users, null) try { @@ -729,7 +741,6 @@ internal class DefaultCryptoService @Inject constructor( } private suspend fun claimKeys(request: Request.KeysClaim) { - // TODO this needs to be locked per call val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(request.oneTimeKeys) val response = oneTimeKeysForUsersDeviceTask.execute(claimParams) val adapter = MoshiProvider @@ -741,18 +752,19 @@ internal class DefaultCryptoService @Inject constructor( } private suspend fun sendOutgoingRequests() { - // TODO this needs to be locked per call - // TODO these requests should be sent out in parallel - for (outgoingRequest in olmMachine!!.outgoingRequests()) { - when (outgoingRequest) { - is Request.KeysUpload -> { - uploadKeys(outgoingRequest) - } - is Request.KeysQuery -> { - queryKeys(outgoingRequest) - } - is Request.ToDevice -> { - // Timber.v("HELLO TO DEVICE REQUEST ${outgoingRequest.body}") + outgointRequestsLock.withLock { + // TODO these requests should be sent out in parallel + for (outgoingRequest in olmMachine!!.outgoingRequests()) { + when (outgoingRequest) { + is Request.KeysUpload -> { + uploadKeys(outgoingRequest) + } + is Request.KeysQuery -> { + queryKeys(outgoingRequest) + } + is Request.ToDevice -> { + // Timber.v("HELLO TO DEVICE REQUEST ${outgoingRequest.body}") + } } } }