blob: 4f1df7268856d408d60d27286b6f9d8634d1b6df [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.android.connecteddevice.connection
import android.content.Context
import android.os.ParcelUuid
import android.util.Base64
import androidx.room.Room
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.android.companionprotos.VerificationCode
import com.google.android.companionprotos.VerificationCodeState
import com.google.android.connecteddevice.connection.MultiProtocolSecureChannel.ChannelError
import com.google.android.connecteddevice.connection.MultiProtocolSecureChannel.MessageError
import com.google.android.connecteddevice.connection.MultiProtocolSecureChannel.ShowVerificationCodeListener
import com.google.android.connecteddevice.model.DeviceMessage
import com.google.android.connecteddevice.model.DeviceMessage.OperationType
import com.google.android.connecteddevice.oob.OobRunner
import com.google.android.connecteddevice.storage.ConnectedDeviceDatabase
import com.google.android.connecteddevice.storage.ConnectedDeviceStorage
import com.google.android.connecteddevice.storage.CryptoHelper
import com.google.android.connecteddevice.transport.ConnectChallenge
import com.google.android.connecteddevice.transport.ConnectionProtocol
import com.google.android.connecteddevice.transport.IDataSendCallback
import com.google.android.connecteddevice.transport.IDiscoveryCallback
import com.google.android.connecteddevice.transport.ProtocolDevice
import com.google.android.connecteddevice.util.ByteUtils
import com.google.android.encryptionrunner.EncryptionRunnerFactory
import com.google.android.encryptionrunner.FakeEncryptionRunner
import com.google.android.encryptionrunner.HandshakeException
import com.google.common.truth.Truth.assertThat
import com.google.common.util.concurrent.MoreExecutors.directExecutor
import com.google.protobuf.ByteString
import com.nhaarman.mockitokotlin2.any
import com.nhaarman.mockitokotlin2.argumentCaptor
import com.nhaarman.mockitokotlin2.doAnswer
import com.nhaarman.mockitokotlin2.doReturn
import com.nhaarman.mockitokotlin2.eq
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.never
import com.nhaarman.mockitokotlin2.spy
import com.nhaarman.mockitokotlin2.times
import com.nhaarman.mockitokotlin2.verify
import com.nhaarman.mockitokotlin2.whenever
import java.security.SignatureException
import java.util.UUID
import java.util.zip.DataFormatException
import java.util.zip.Inflater
import org.junit.Assert.fail
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
// Protocol ids of the two protocol streams (stream1 / stream2) used by the tests below.
private const val PROTOCOL_ID_1 = "testProtocol1"
private const val PROTOCOL_ID_2 = "testProtocol2"

// Fixed value returned by the stubbed ConnectedDeviceStorage.uniqueId.
private val SERVER_DEVICE_ID: UUID = UUID.fromString("a29f0c74-2014-4b14-ac02-be6ed15b545a")
/**
 * Unit tests for [MultiProtocolSecureChannel].
 *
 * Each test builds a spied channel over [stream1] using a fake encryption runner (see
 * [setupSecureChannel]). Tests that need a channel past the association handshake drive the fake
 * handshake through [completeHandshakeAndSaveTheKey].
 */
@RunWith(AndroidJUnit4::class)
class MultiProtocolSecureChannelTest {
  private val context = ApplicationProvider.getApplicationContext<Context>()
  private val stream1 = spy(ProtocolStream(ProtocolDevice(TestProtocol(), PROTOCOL_ID_1)))
  private val stream2 = spy(ProtocolStream(ProtocolDevice(TestProtocol(), PROTOCOL_ID_2)))
  private val mockInflater: Inflater = mock()
  private val mockCallback: MultiProtocolSecureChannel.Callback = mock()
  private val mockOobRunner: OobRunner = mock()
  private val mockShowVerificationCodeListener: ShowVerificationCodeListener = mock()

  private lateinit var secureChannel: MultiProtocolSecureChannel
  private lateinit var spyStorage: ConnectedDeviceStorage

  @Before
  fun setUp() {
    val database =
      Room.inMemoryDatabaseBuilder(context, ConnectedDeviceDatabase::class.java)
        .allowMainThreadQueries()
        .setQueryExecutor(directExecutor())
        .build()
        .associatedDeviceDao()
    spyStorage =
      spy(ConnectedDeviceStorage(context, Base64CryptoHelper(), database, directExecutor()))
    // Spies must be stubbed with doReturn(...).whenever(spy): whenever(spy.getter) would invoke
    // the real getter while stubbing.
    doReturn(SERVER_DEVICE_ID).whenever(spyStorage).uniqueId
  }

  @Test
  fun decryptMessage_doesNothingForUnencryptedMessage() {
    val testPayload = ByteUtils.randomBytes(10)
    completeHandshakeAndSaveTheKey()
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.CLIENT_MESSAGE,
        testPayload
      )

    secureChannel.decryptMessage(message)

    assertThat(message.message).isEqualTo(testPayload)
  }

  @Test
  fun decryptMessage_decryptsEncryptedMessage() {
    completeHandshakeAndSaveTheKey()
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )

    assertThat(secureChannel.decryptMessage(message)).isTrue()
  }

  @Test
  fun decryptMessage_onMessageReceivedErrorForEncryptedMessageWithNoKey() {
    setupSecureChannel(true)
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )

    assertThat(secureChannel.decryptMessage(message)).isFalse()
    verify(mockCallback).onMessageReceivedError(MessageError.MESSAGE_ERROR_DECRYPTION_FAILURE)
    assertThat(message.message).isEmpty()
  }

  @Test
  fun onDeviceMessageReceived_onEstablishSecureChannelFailureBadHandshakeMessage() {
    setupSecureChannel(true)
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.ENCRYPTION_HANDSHAKE,
        ByteUtils.randomBytes(10)
      )

    secureChannel.onDeviceMessageReceived(message)

    verify(mockCallback).onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_MSG)
  }

  @Test
  fun onDeviceMessageReceived_onMessageReceivedNotIssuedForEmptyMessage() {
    completeHandshakeAndSaveTheKey()
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.CLIENT_MESSAGE,
        /* message= */ byteArrayOf()
      )

    secureChannel.onDeviceMessageReceived(message)

    verify(mockCallback, never()).onMessageReceived(any())
  }

  @Test
  fun onDeviceMessageReceived_processHandshakeExceptionIssuesSecureChannelFailureCallback() {
    setupSecureChannel(true)
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.ENCRYPTION_HANDSHAKE,
        FakeEncryptionRunner.INIT_MESSAGE
      )
    // Stub the spy with doAnswer: whenever(secureChannel.processHandshake(...)) would run the real
    // processHandshake on INIT_MESSAGE during stubbing. doAnswer (rather than doThrow) is used so
    // the checked HandshakeException is not validated against the method signature.
    doAnswer<Unit> { throw HandshakeException("test") }
      .whenever(secureChannel)
      .processHandshake(any())

    secureChannel.onDeviceMessageReceived(message)

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_HANDSHAKE)
  }

  @Test
  fun onDeviceMessageReceived_processClientMessageIssuesMessageReceivedErrorCallback() {
    setupSecureChannel(true)
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )

    secureChannel.onDeviceMessageReceived(message)

    verify(mockCallback).onMessageReceivedError(MessageError.MESSAGE_ERROR_DECRYPTION_FAILURE)
  }

  @Test
  fun onDeviceMessageReceived_noExceptionWhenReceivedUnknownMessageType() {
    completeHandshakeAndSaveTheKey()
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.UNKNOWN,
        ByteUtils.randomBytes(10)
      )

    try {
      secureChannel.onDeviceMessageReceived(message)
    } catch (e: Exception) {
      fail("Should not throw exception.")
    }
  }

  @Test
  fun onDeviceMessageReceived_issueOnMessageReceivedWhenHandshakeFinished() {
    completeHandshakeAndSaveTheKey()
    val message =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.ENCRYPTION_HANDSHAKE,
        ByteUtils.randomBytes(10)
      )

    secureChannel.onDeviceMessageReceived(message)

    verify(mockCallback).onMessageReceived(message)
  }

  @Test
  fun decompressMessage_returnsOriginalMessageIfOriginalSizeIsZero() {
    val testPayload = ByteUtils.randomBytes(10)
    completeHandshakeAndSaveTheKey()
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.CLIENT_MESSAGE,
        testPayload
      )
    deviceMessage.originalMessageSize = 0

    assertThat(secureChannel.decompressMessage(deviceMessage)).isTrue()
    assertThat(deviceMessage.message).isEqualTo(testPayload)
  }

  @Test
  fun decompressMessage_returnsFalseWhenThereAreDataFormatException() {
    completeHandshakeAndSaveTheKey()
    whenever(mockInflater.inflate(any())).then { throw DataFormatException() }
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )
    // A non-zero original size forces the decompression path.
    deviceMessage.originalMessageSize = 1

    assertThat(secureChannel.decompressMessage(deviceMessage)).isFalse()
    verify(mockCallback).onMessageReceivedError(MessageError.MESSAGE_ERROR_DECOMPRESSION_FAILURE)
  }

  @Test
  fun compressMessage_returnsCompressedMessageWithOriginalSize() {
    completeHandshakeAndSaveTheKey()
    val message = ByteArray(100)
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.CLIENT_MESSAGE,
        message
      )

    secureChannel.compressMessage(deviceMessage)

    assertThat(deviceMessage.message).isNotEqualTo(message)
    assertThat(deviceMessage.originalMessageSize).isEqualTo(message.size)
  }

  @Test
  fun compressedMessageCanBeDecompressed() {
    // Use a real Inflater: the shared mockInflater never writes to the output buffer, so a
    // round trip through it would leave a zero-filled array and prove nothing.
    completeHandshakeAndSaveTheKey(inflater = Inflater())
    // Non-zero but highly compressible payload, so a no-op decompression cannot reproduce it.
    val message = ByteArray(100) { (it % 4 + 1).toByte() }
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ false,
        OperationType.CLIENT_MESSAGE,
        message
      )

    secureChannel.compressMessage(deviceMessage)
    // Guard: the message was actually compressed, so decompression below is really exercised.
    assertThat(deviceMessage.originalMessageSize).isEqualTo(message.size)

    assertThat(secureChannel.decompressMessage(deviceMessage)).isTrue()
    assertThat(deviceMessage.message).isEqualTo(message)
  }

  @Test
  fun addStream_notifyCallbackWhenMessageReceived() {
    completeHandshakeAndSaveTheKey()
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )

    secureChannel.addStream(stream2)

    assertThat(stream2.messageReceivedListener).isNotNull()
    stream2.messageReceivedListener!!.onMessageReceived(deviceMessage)
    verify(mockCallback).onMessageReceived(any())
  }

  @Test
  fun addStream_notifyCallbackWhenProtocolDisconnected() {
    setupSecureChannel(true)

    assertThat(stream1.protocolDisconnectListener).isNotNull()
    stream1.protocolDisconnectListener!!.onProtocolDisconnected()

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_DEVICE_DISCONNECTED)
  }

  @Test
  fun addStream_doNotNotifyCallbackWhenProtocolDisconnected() {
    setupSecureChannel(true)
    secureChannel.addStream(stream2)

    assertThat(stream2.protocolDisconnectListener).isNotNull()
    stream2.protocolDisconnectListener!!.onProtocolDisconnected()

    verify(mockCallback, never())
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_DEVICE_DISCONNECTED)
  }

  @Test
  fun sendClientMessage_FailedToSendMessageWithNullKey() {
    setupSecureChannel(true)
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )

    assertThat(secureChannel.sendClientMessage(deviceMessage)).isFalse()
  }

  @Test
  fun sendClientMessage_FailedToSendMessageWithEmptyStream() {
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )
    completeHandshakeAndSaveTheKey()
    // Disconnecting the only stream leaves the channel with no stream to send on.
    stream1.protocolDisconnectListener!!.onProtocolDisconnected()

    assertThat(secureChannel.sendClientMessage(deviceMessage)).isFalse()
  }

  @Test
  fun sendClientMessage_successfullySendMessage() {
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        ByteUtils.randomBytes(10)
      )
    completeHandshakeAndSaveTheKey()

    assertThat(secureChannel.sendClientMessage(deviceMessage)).isTrue()
  }

  @Test
  fun sendClientMessage_successfullyEncryptAndSetMessage() {
    val testPayload = ByteUtils.randomBytes(10)
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        UUID.randomUUID(),
        /* isMessageEncrypted= */ true,
        OperationType.CLIENT_MESSAGE,
        testPayload
      )
    completeHandshakeAndSaveTheKey()

    secureChannel.sendClientMessage(deviceMessage)

    assertThat(deviceMessage.message).isNotEqualTo(testPayload)
  }

  @Test
  fun association_secureChannelEstablishedSuccessfully() {
    val clientId = UUID.randomUUID()
    setupSecureChannel(false)
    secureChannel.showVerificationCodeListener = mockShowVerificationCodeListener
    argumentCaptor<DeviceMessage>().apply {
      initHandshakeMessage()
      verify(stream1).sendMessage(capture())
      val response = firstValue.message
      assertThat(response).isEqualTo(FakeEncryptionRunner.INIT_RESPONSE)
    }
    respondToContinueMessage()
    val testVerificationCodeMessage =
      VerificationCode.newBuilder().setState(VerificationCodeState.VISUAL_VERIFICATION).build()

    receiveHandshakeMessage(testVerificationCodeMessage.toByteArray())
    verify(mockShowVerificationCodeListener).showVerificationCode(any())
    secureChannel.notifyVerificationCodeAccepted()
    secureChannel.setDeviceIdDuringAssociation(clientId)

    verify(spyStorage).saveEncryptionKey(eq(clientId.toString()), any())
    verify(mockCallback).onSecureChannelEstablished()
  }

  @Test
  fun association_wrongInitHandshakeMessage_issueInvalidHandshakeError() {
    setupSecureChannel(false)

    // Skips the init message and sends the continue message first.
    respondToContinueMessage()

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_HANDSHAKE)
  }

  @Test
  fun association_wrongRespondToContinueMessage_issueInvalidHandshakeError() {
    setupSecureChannel(false)
    initHandshakeMessage()

    // Sends the init message again where the continue message is expected.
    initHandshakeMessage()

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_HANDSHAKE)
  }

  @Test
  fun reconnect_secureChannelEstablishedSuccessfully() {
    val clientId = UUID.randomUUID()
    doReturn(byteArrayOf()).whenever(spyStorage).getEncryptionKey(clientId.toString())
    setupSecureChannel(true, clientId.toString())

    initHandshakeMessage()
    respondToContinueMessage()
    respondToResumeMessage()

    verify(spyStorage).saveEncryptionKey(eq(clientId.toString()), any())
    verify(mockCallback).onSecureChannelEstablished()
  }

  @Test
  fun reconnect_deviceIdNotSet_issueInvalidStateError() {
    setupSecureChannel(true)

    initHandshakeMessage()
    respondToContinueMessage()
    respondToResumeMessage()

    verify(mockCallback).onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_STATE)
  }

  @Test
  fun reconnect_keyNotSet_issueInvalidKeyError() {
    setupSecureChannel(true, UUID.randomUUID().toString())

    initHandshakeMessage()
    respondToContinueMessage()
    respondToResumeMessage()

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_ENCRYPTION_KEY)
  }

  @Test
  fun processHandshakeResumingSession_incorrectHandshakeState_issueInvalidStateError() {
    val clientId = UUID.randomUUID().toString()
    setupSecureChannel(true, clientId)
    doReturn(byteArrayOf()).whenever(spyStorage).getEncryptionKey(clientId)

    initHandshakeMessage()
    respondToContinueMessage()
    respondToResumeMessage(FakeEncryptionRunner.RECONNECTION_MESSAGE_STATE_ERROR)

    verify(mockCallback).onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_STATE)
  }

  @Test
  fun processHandshakeResumingSession_emptyNewKey_issueInvalidKeyError() {
    val clientId = UUID.randomUUID().toString()
    setupSecureChannel(true, clientId)
    doReturn(byteArrayOf()).whenever(spyStorage).getEncryptionKey(clientId)

    initHandshakeMessage()
    respondToContinueMessage()
    respondToResumeMessage(FakeEncryptionRunner.RECONNECTION_MESSAGE_KEY_ERROR)

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_ENCRYPTION_KEY)
  }

  @Test
  fun processHandshakeResumingSession_emptyNextMessage_issueInvalidMessageError() {
    val clientId = UUID.randomUUID().toString()
    setupSecureChannel(true, clientId)
    doReturn(byteArrayOf()).whenever(spyStorage).getEncryptionKey(clientId)

    initHandshakeMessage()
    respondToContinueMessage()
    respondToResumeMessage(FakeEncryptionRunner.RECONNECTION_MESSAGE_EMPTY_RESPONSE)

    verify(mockCallback).onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_MSG)
  }

  @Test
  fun processHandshake_invalidState_issueInvalidStateError() {
    setupSecureChannel(false)
    secureChannel.showVerificationCodeListener = mockShowVerificationCodeListener
    initHandshakeMessage()
    respondToContinueMessage()

    secureChannel.processHandshake(byteArrayOf())

    verify(mockCallback).onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_STATE)
  }

  @Test
  fun processHandshake_receivedMessageAfterFinish_ignoreMessage() {
    setupSecureChannel(false)
    secureChannel.showVerificationCodeListener = mockShowVerificationCodeListener
    initHandshakeMessage()
    respondToContinueMessage()
    secureChannel.notifyVerificationCodeAccepted()

    secureChannel.processHandshake(byteArrayOf())

    verify(mockCallback, never()).onEstablishSecureChannelFailure(any())
  }

  @Test
  fun processHandshakeInitialization_nextMessageIsNull_issueInvalidHandshakeError() {
    setupSecureChannel(false)

    initHandshakeMessage(FakeEncryptionRunner.INIT_MESSAGE_EMPTY_RESPONSE)

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_HANDSHAKE)
  }

  @Test
  fun processVerificationCode_listenerNotSet_issueInvalidStateError() {
    setupSecureChannel(false)
    initHandshakeMessage()
    respondToContinueMessage()
    val testVerificationCodeMessage =
      VerificationCode.newBuilder().setState(VerificationCodeState.VISUAL_VERIFICATION).build()

    receiveHandshakeMessage(testVerificationCodeMessage.toByteArray())

    verify(mockCallback).onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_STATE)
  }

  @Test
  fun setDeviceIdDuringAssociation_encryptionKeyNotSet_issueInvalidKeyError() {
    setupSecureChannel(false)

    secureChannel.setDeviceIdDuringAssociation(UUID.randomUUID())

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_ENCRYPTION_KEY)
  }

  @Test
  fun notifyVerificationCodeAccepted_invalidState_issueInvalidVerificationError() {
    setupSecureChannel(false)
    initHandshakeMessage()

    secureChannel.notifyVerificationCodeAccepted()

    verify(mockCallback)
      .onEstablishSecureChannelFailure(ChannelError.CHANNEL_ERROR_INVALID_VERIFICATION)
  }

  @Test
  fun processVerificationCodeMessage_oobVerification_verifyOobCode() {
    setupSecureChannel(false)
    initHandshakeMessage()
    respondToContinueMessage()
    val testPayload = "testPayload".toByteArray()
    val testVerificationCodeMessage =
      VerificationCode.newBuilder().run {
        setState(VerificationCodeState.OOB_VERIFICATION)
        setPayload(ByteString.copyFrom(testPayload))
        build()
      }

    receiveHandshakeMessage(testVerificationCodeMessage.toByteArray())

    verify(mockOobRunner).decryptData(testPayload)
  }

  @Test
  fun createOobResponse_oobCodeMatch_sendCorrectMessage() {
    val testPayload = "testPayload".toByteArray()
    whenever(mockOobRunner.decryptData(testPayload))
      .thenReturn(FakeEncryptionRunner.VERIFICATION_CODE)
    whenever(mockOobRunner.encryptData(FakeEncryptionRunner.VERIFICATION_CODE))
      .thenReturn(testPayload)
    setupSecureChannel(false)
    initHandshakeMessage()
    respondToContinueMessage()
    val testVerificationCodeMessage =
      VerificationCode.newBuilder().run {
        setState(VerificationCodeState.OOB_VERIFICATION)
        setPayload(ByteString.copyFrom(testPayload))
        build()
      }

    receiveHandshakeMessage(testVerificationCodeMessage.toByteArray())

    // The first message sent on stream1 is the init response; the second is the OOB confirmation.
    val confirmationMessage =
      argumentCaptor<DeviceMessage>().run {
        verify(stream1, times(2)).sendMessage(capture())
        secondValue.message
      }
    val verificationCodeMessage = VerificationCode.parseFrom(confirmationMessage)
    assertThat(verificationCodeMessage.state).isEqualTo(VerificationCodeState.OOB_VERIFICATION)
  }

  @Test
  fun onVisualVerificationCodeConfirmed_sendConfirmationMessage() {
    setupSecureChannel(false)
    initHandshakeMessage()
    respondToContinueMessage()
    val testVerificationCodeMessage =
      VerificationCode.newBuilder().setState(VerificationCodeState.VISUAL_VERIFICATION).build()

    receiveHandshakeMessage(testVerificationCodeMessage.toByteArray())
    secureChannel.notifyVerificationCodeAccepted()

    // The first message sent on stream1 is the init response; the second is the confirmation.
    val confirmationMessage =
      argumentCaptor<DeviceMessage>().run {
        verify(stream1, times(2)).sendMessage(capture())
        secondValue.message
      }
    val verificationCodeMessage = VerificationCode.parseFrom(confirmationMessage)
    assertThat(verificationCodeMessage.state).isEqualTo(VerificationCodeState.VISUAL_CONFIRMATION)
  }

  /**
   * Assigns [secureChannel] a spy of a [MultiProtocolSecureChannel] over [stream1], backed by
   * [spyStorage], [mockOobRunner] and a fake encryption runner, with [mockCallback] registered.
   *
   * @param isReconnect value passed to the fake encryption runner's `setIsReconnect`.
   * @param deviceId device id handed to the channel, or `null` to leave it unset.
   * @param inflater inflater the channel uses for decompression; defaults to [mockInflater].
   */
  private fun setupSecureChannel(
    isReconnect: Boolean,
    deviceId: String? = null,
    inflater: Inflater = mockInflater
  ) {
    val encryptionRunner = EncryptionRunnerFactory.newFakeRunner()
    encryptionRunner.setIsReconnect(isReconnect)
    secureChannel =
      spy(
        MultiProtocolSecureChannel(
            stream1,
            spyStorage,
            encryptionRunner,
            mockOobRunner,
            deviceId = deviceId,
            inflater = inflater
          )
          .apply { callback = mockCallback }
      )
  }

  /** Delivers an unencrypted [OperationType.ENCRYPTION_HANDSHAKE] message carrying [payload]. */
  private fun receiveHandshakeMessage(payload: ByteArray) {
    val deviceMessage =
      DeviceMessage.createOutgoingMessage(
        /* recipient= */ null,
        /* isMessageEncrypted= */ false,
        OperationType.ENCRYPTION_HANDSHAKE,
        payload
      )
    secureChannel.onDeviceMessageReceived(deviceMessage)
  }

  /** Delivers the handshake init message (the fake runner's `INIT_MESSAGE` by default). */
  private fun initHandshakeMessage(message: ByteArray = FakeEncryptionRunner.INIT_MESSAGE) =
    receiveHandshakeMessage(message)

  /** Delivers the handshake continue message (the fake runner's `CLIENT_RESPONSE` by default). */
  private fun respondToContinueMessage(message: ByteArray = FakeEncryptionRunner.CLIENT_RESPONSE) =
    receiveHandshakeMessage(message)

  /** Delivers the reconnection resume message (a placeholder payload by default). */
  private fun respondToResumeMessage(message: ByteArray = "Placeholder Message".toByteArray()) =
    receiveHandshakeMessage(message)

  /**
   * Sets up an association (non-reconnect) channel and drives the fake handshake through the
   * init message, the continue message and acceptance of the verification code.
   *
   * @param inflater inflater the channel uses for decompression; defaults to [mockInflater].
   */
  private fun completeHandshakeAndSaveTheKey(inflater: Inflater = mockInflater) {
    setupSecureChannel(false, inflater = inflater)
    secureChannel.showVerificationCodeListener = mockShowVerificationCodeListener
    initHandshakeMessage()
    respondToContinueMessage()
    secureChannel.notifyVerificationCodeAccepted()
  }
}
/**
 * No-op [ConnectionProtocol] backing the [ProtocolDevice]s in these tests.
 *
 * Device verification is reported as not required, every discovery/transport operation does
 * nothing, and the max write size is 0.
 */
private open class TestProtocol : ConnectionProtocol() {
  override fun isDeviceVerificationRequired(): Boolean = false

  override fun getMaxWriteSize(protocolId: String): Int = 0

  override fun startAssociationDiscovery(
    name: String,
    identifier: ParcelUuid,
    callback: IDiscoveryCallback
  ) = Unit

  override fun stopAssociationDiscovery() = Unit

  override fun startConnectionDiscovery(
    id: ParcelUuid,
    challenge: ConnectChallenge,
    callback: IDiscoveryCallback
  ) = Unit

  override fun stopConnectionDiscovery(id: ParcelUuid) = Unit

  override fun sendData(protocolId: String, data: ByteArray, callback: IDataSendCallback?) = Unit

  override fun disconnectDevice(protocolId: String) = Unit

  override fun reset() = Unit
}
/**
 * Test-only [CryptoHelper] that performs no real encryption: values are simply Base64-encoded on
 * [encrypt] and Base64-decoded on [decrypt] using [Base64.DEFAULT] flags.
 */
private class Base64CryptoHelper : CryptoHelper {
  override fun encrypt(value: ByteArray?): String? {
    val flags = Base64.DEFAULT
    return Base64.encodeToString(value, flags)
  }

  override fun decrypt(value: String?): ByteArray? {
    val flags = Base64.DEFAULT
    return Base64.decode(value, flags)
  }
}