blob: f0a275a3c424d7f353cc1e2ff6a86c2ba3f44980 [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.android.connecteddevice.connection
import android.content.Context
import android.os.ParcelUuid
import android.util.Base64
import androidx.room.Room
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.android.companionprotos.VerificationCode
import com.google.android.companionprotos.VerificationCodeState
import com.google.android.connecteddevice.connection.MultiProtocolSecureChannel.ShowVerificationCodeListener
import com.google.android.connecteddevice.model.DeviceMessage
import com.google.android.connecteddevice.model.DeviceMessage.OperationType
import com.google.android.connecteddevice.oob.OobRunner
import com.google.android.connecteddevice.storage.ConnectedDeviceDatabase
import com.google.android.connecteddevice.storage.ConnectedDeviceStorage
import com.google.android.connecteddevice.storage.CryptoHelper
import com.google.android.connecteddevice.transport.ConnectChallenge
import com.google.android.connecteddevice.transport.ConnectionProtocol
import com.google.android.connecteddevice.transport.IDataSendCallback
import com.google.android.connecteddevice.transport.IDiscoveryCallback
import com.google.android.connecteddevice.transport.ProtocolDevice
import com.google.android.encryptionrunner.EncryptionRunnerFactory
import com.google.android.encryptionrunner.FakeEncryptionRunner
import com.google.common.truth.Truth.assertThat
import com.google.common.util.concurrent.MoreExecutors.directExecutor
import com.google.protobuf.ByteString
import com.nhaarman.mockitokotlin2.any
import com.nhaarman.mockitokotlin2.argumentCaptor
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.spy
import com.nhaarman.mockitokotlin2.times
import com.nhaarman.mockitokotlin2.verify
import com.nhaarman.mockitokotlin2.whenever
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
/** Protocol id handed to the [ProtocolDevice] that backs the stream used by the tests below. */
private const val PROTOCOL_ID = "testProtocol"
/**
 * Tests for how [MultiProtocolSecureChannelV4] handles [VerificationCode] messages received
 * during the encryption handshake.
 *
 * Every test builds a channel around the fake runner from
 * [EncryptionRunnerFactory.newFakeRunner], feeds it the fake runner's init and client-response
 * messages, and then delivers a verification-code message.
 */
@RunWith(AndroidJUnit4::class)
class MultiProtocolSecureChannelV4Test {
  private val context = ApplicationProvider.getApplicationContext<Context>()
  private val stream1 = spy(ProtocolStream(ProtocolDevice(TestProtocolV4(), PROTOCOL_ID)))
  private val mockOobRunner: OobRunner = mock()
  private val mockShowVerificationCodeListener: ShowVerificationCodeListener = mock()
  private val mockCallback: MultiProtocolSecureChannel.Callback = mock()

  // The database is held (not just its DAO) so that it can be closed after each test.
  private lateinit var database: ConnectedDeviceDatabase
  private lateinit var storage: ConnectedDeviceStorage

  /** Creates a fresh in-memory database and a storage instance on top of it. */
  @Before
  fun setUp() {
    database =
      Room.inMemoryDatabaseBuilder(context, ConnectedDeviceDatabase::class.java)
        .allowMainThreadQueries()
        .setQueryExecutor(directExecutor())
        .build()
    storage =
      ConnectedDeviceStorage(
        context,
        Base64CryptoHelperV4(),
        database.associatedDeviceDao(),
        directExecutor()
      )
  }

  /** Closes the in-memory database so it does not leak from one test into the next. */
  @After
  fun tearDown() {
    // Guarded so that a failure inside setUp is not masked by an uninitialized-property error.
    if (::database.isInitialized) {
      database.close()
    }
  }

  @Test
  fun processVerificationCodeMessage_oobVerification_verifyOobCode() {
    val secureChannel = setupSecureChannel(isReconnect = false)
    completeHandshake(secureChannel)
    val testPayload = "testPayload".toByteArray()

    secureChannel.onDeviceMessageReceived(
      createVerificationCodeMessage(VerificationCodeState.OOB_VERIFICATION, testPayload)
    )

    verify(mockOobRunner).decryptData(testPayload)
  }

  @Test
  fun createOobResponse_oobCodeMatch_sendCorrectMessage() {
    val testPayload = "testPayload".toByteArray()
    whenever(mockOobRunner.decryptData(testPayload))
      .thenReturn(FakeEncryptionRunner.VERIFICATION_CODE)
    whenever(mockOobRunner.encryptData(FakeEncryptionRunner.VERIFICATION_CODE))
      .thenReturn(testPayload)
    val secureChannel = setupSecureChannel(isReconnect = false)
    completeHandshake(secureChannel)

    secureChannel.onDeviceMessageReceived(
      createVerificationCodeMessage(VerificationCodeState.OOB_VERIFICATION, testPayload)
    )

    assertThat(captureSecondSentVerificationCode().state)
      .isEqualTo(VerificationCodeState.OOB_VERIFICATION)
  }

  @Test
  fun processVerificationCodeMessage_visualVerification_invokeListener() {
    val secureChannel = setupSecureChannel(isReconnect = false)
    completeHandshake(secureChannel)

    secureChannel.onDeviceMessageReceived(
      createVerificationCodeMessage(VerificationCodeState.VISUAL_VERIFICATION)
    )

    verify(mockShowVerificationCodeListener).showVerificationCode(any())
  }

  @Test
  fun onVisualVerificationCodeConfirmed_sendConfirmationMessage() {
    val secureChannel = setupSecureChannel(isReconnect = false)
    completeHandshake(secureChannel)

    secureChannel.onDeviceMessageReceived(
      createVerificationCodeMessage(VerificationCodeState.VISUAL_VERIFICATION)
    )
    secureChannel.notifyVerificationCodeAccepted()

    assertThat(captureSecondSentVerificationCode().state)
      .isEqualTo(VerificationCodeState.VISUAL_CONFIRMATION)
  }

  /**
   * Builds the channel under test on top of [stream1], [storage] and a fake encryption runner,
   * and attaches the mock callback and verification-code listener.
   *
   * @param isReconnect forwarded to the fake runner's `setIsReconnect`.
   * @param deviceId forwarded to the channel constructor; `null` by default.
   */
  private fun setupSecureChannel(
    isReconnect: Boolean,
    deviceId: String? = null
  ): MultiProtocolSecureChannelV4 {
    val encryptionRunner = EncryptionRunnerFactory.newFakeRunner()
    encryptionRunner.setIsReconnect(isReconnect)
    return MultiProtocolSecureChannelV4(
        stream1,
        storage,
        encryptionRunner,
        deviceId = deviceId,
        oobRunner = mockOobRunner
      )
      .apply {
        callback = mockCallback
        showVerificationCodeListener = mockShowVerificationCodeListener
      }
  }

  /**
   * Delivers the fake runner's init message followed by its client response to [channel], which
   * is the common preamble of every test in this class.
   */
  private fun completeHandshake(channel: MultiProtocolSecureChannelV4) {
    initHandshakeMessage(channel)
    respondToContinueMessage(channel)
  }

  /** Delivers [message] to [channel] as an unencrypted handshake message. */
  private fun initHandshakeMessage(
    channel: MultiProtocolSecureChannelV4,
    message: ByteArray = FakeEncryptionRunner.INIT_MESSAGE
  ) {
    channel.onDeviceMessageReceived(createHandshakeMessage(message))
  }

  /** Delivers [message] to [channel] as an unencrypted handshake message. */
  private fun respondToContinueMessage(
    channel: MultiProtocolSecureChannelV4,
    message: ByteArray = FakeEncryptionRunner.CLIENT_RESPONSE
  ) {
    channel.onDeviceMessageReceived(createHandshakeMessage(message))
  }

  /**
   * Wraps [handshakePayload] in an unencrypted [OperationType.ENCRYPTION_HANDSHAKE] message with
   * no recipient.
   */
  private fun createHandshakeMessage(handshakePayload: ByteArray) =
    DeviceMessage.createOutgoingMessage(
      /* recipient= */ null,
      /* isMessageEncrypted= */ false,
      OperationType.ENCRYPTION_HANDSHAKE,
      handshakePayload
    )

  /**
   * Builds a handshake message carrying a serialized [VerificationCode] in the given [state].
   *
   * @param codePayload bytes for the proto's payload field; the field is left unset when `null`.
   */
  private fun createVerificationCodeMessage(
    state: VerificationCodeState,
    codePayload: ByteArray? = null
  ) =
    createHandshakeMessage(
      VerificationCode.newBuilder()
        .setState(state)
        .apply { if (codePayload != null) setPayload(ByteString.copyFrom(codePayload)) }
        .build()
        .toByteArray()
    )

  /**
   * Verifies that exactly two messages were sent on [stream1] and parses the second one as a
   * [VerificationCode].
   */
  private fun captureSecondSentVerificationCode(): VerificationCode {
    val sentMessages = argumentCaptor<DeviceMessage>()
    verify(stream1, times(2)).sendMessage(sentMessages.capture())
    return VerificationCode.parseFrom(sentMessages.secondValue.message)
  }
}
/**
 * Inert [ConnectionProtocol] used to back the test stream: every operation is a no-op, device
 * verification is reported as not required, and the maximum write size is reported as 0.
 */
private class TestProtocolV4 : ConnectionProtocol() {
  override fun isDeviceVerificationRequired(): Boolean {
    return false
  }

  override fun startAssociationDiscovery(
    name: String,
    identifier: ParcelUuid,
    callback: IDiscoveryCallback
  ) = Unit

  override fun startConnectionDiscovery(
    id: ParcelUuid,
    challenge: ConnectChallenge,
    callback: IDiscoveryCallback
  ) = Unit

  override fun stopAssociationDiscovery() = Unit

  override fun stopConnectionDiscovery(id: ParcelUuid) = Unit

  override fun sendData(protocolId: String, data: ByteArray, callback: IDataSendCallback?) = Unit

  override fun disconnectDevice(protocolId: String) = Unit

  override fun reset() = Unit

  override fun getMaxWriteSize(protocolId: String): Int = 0
}
/**
 * [CryptoHelper] for tests that substitutes Base64 encoding/decoding for real encryption.
 *
 * Both operations accept `null` and return `null` for it. `android.util.Base64` itself throws a
 * `NullPointerException` on `null` input, so `null` must not be forwarded to it.
 */
private class Base64CryptoHelperV4 : CryptoHelper {
  /** Returns [value] Base64-encoded with [Base64.DEFAULT] flags, or `null` if [value] is `null`. */
  override fun encrypt(value: ByteArray?): String? =
    value?.let { Base64.encodeToString(it, Base64.DEFAULT) }

  /** Returns the bytes decoded from Base64 string [value], or `null` if [value] is `null`. */
  override fun decrypt(value: String?): ByteArray? =
    value?.let { Base64.decode(it, Base64.DEFAULT) }
}