blob: 4f46cfb3fb4d1b15e2a770b0365ec0774f0bc5f4 [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.android.connecteddevice.connection
import android.os.ParcelUuid
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.android.companionprotos.DeviceMessageProto.Message
import com.google.android.companionprotos.OperationProto.OperationType
import com.google.android.companionprotos.PacketProto.Packet
import com.google.android.connecteddevice.connection.ProtocolStream.MessageReceivedListener
import com.google.android.connecteddevice.connection.ProtocolStream.ProtocolDisconnectListener
import com.google.android.connecteddevice.model.DeviceMessage
import com.google.android.connecteddevice.transport.ConnectChallenge
import com.google.android.connecteddevice.transport.ConnectionProtocol
import com.google.android.connecteddevice.transport.IDataSendCallback
import com.google.android.connecteddevice.transport.IDiscoveryCallback
import com.google.android.connecteddevice.transport.ProtocolDevice
import com.google.android.connecteddevice.util.ByteUtils
import com.google.common.truth.Truth
import com.google.common.truth.Truth.assertThat
import com.google.common.util.concurrent.MoreExecutors.directExecutor
import com.google.protobuf.ByteString
import com.google.protobuf.ExtensionRegistryLite
import java.util.UUID
import java.util.concurrent.ThreadLocalRandom
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.spy
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.eq
import org.mockito.kotlin.mock
import org.mockito.kotlin.never
private const val PROTOCOL_ID = "testDevice"
private const val MAX_WRITE_SIZE = 185
@RunWith(AndroidJUnit4::class)
class ProtocolStreamTest {
private val protocol = spy(TestProtocol())
private val stream = ProtocolStream(ProtocolDevice(protocol, PROTOCOL_ID))
@Test
fun sendMessage_smallMessageSendsSinglePacket() {
val recipient = UUID.randomUUID()
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE / 2)
stream.sendMessage(
DeviceMessage.createOutgoingMessage(
recipient,
/* isMessageEncrypted= */ false,
DeviceMessage.OperationType.CLIENT_MESSAGE,
message
)
)
verify(protocol).sendData(eq(PROTOCOL_ID), any(), any())
}
@Test
fun sendMessage_largeMessageSendsMultiplePackets() {
val recipient = UUID.randomUUID()
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE + 1)
stream.sendMessage(
DeviceMessage.createOutgoingMessage(
recipient,
/* isMessageEncrypted= */ false,
DeviceMessage.OperationType.CLIENT_MESSAGE,
message
)
)
verify(protocol, times(2)).sendData(eq(PROTOCOL_ID), any(), any())
}
@Test
fun sendMessage_serializesPayloadIntoPacket() {
val recipient = UUID.randomUUID()
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE / 2)
stream.sendMessage(
DeviceMessage.createOutgoingMessage(
recipient,
/* isMessageEncrypted= */ false,
DeviceMessage.OperationType.CLIENT_MESSAGE,
message
)
)
argumentCaptor<ByteArray>().apply {
verify(protocol).sendData(eq(PROTOCOL_ID), capture(), any())
val packet = Packet.parseFrom(firstValue, ExtensionRegistryLite.getEmptyRegistry())
val sentMessage =
Message.parseFrom(packet.payload.toByteArray(), ExtensionRegistryLite.getEmptyRegistry())
.payload
.toByteArray()
assertThat(message).isEqualTo(sentMessage)
}
}
@Test
fun protocolDisconnect_preventsFutureMessagesFromBeingSent() {
protocol.disconnectDevice(PROTOCOL_ID)
val recipient = UUID.randomUUID()
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE / 2)
stream.sendMessage(
DeviceMessage.createOutgoingMessage(
recipient,
/* isMessageEncrypted= */ false,
DeviceMessage.OperationType.CLIENT_MESSAGE,
message
)
)
verify(protocol, never()).sendData(eq(PROTOCOL_ID), any(), any())
}
@Test
fun protocolDisconnect_invokesProtocolDisconnectListener() {
val listener: ProtocolDisconnectListener = mock()
stream.protocolDisconnectListener = listener
protocol.disconnectDevice(PROTOCOL_ID)
verify(listener).onProtocolDisconnected()
}
@Test
fun protocolOnDataFailedToSend_disconnectsProtocol() {
val failingProtocol = spy(FailingSendProtocol())
val failingStream = ProtocolStream(ProtocolDevice(failingProtocol, PROTOCOL_ID))
val recipient = UUID.randomUUID()
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE / 2)
failingStream.sendMessage(
DeviceMessage.createOutgoingMessage(
recipient,
/* isMessageEncrypted= */ false,
DeviceMessage.OperationType.CLIENT_MESSAGE,
message
)
)
verify(failingProtocol).disconnectDevice(PROTOCOL_ID)
}
@Test
fun onMessageReceived_invokedWithSinglePacketMessage() {
val listener: MessageReceivedListener = mock()
stream.messageReceivedListener = listener
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE / 2)
val packets = createPackets(message)
packets.forEach { protocol.receiveData(it.toByteArray()) }
val captor = argumentCaptor<DeviceMessage>()
verify(listener).onMessageReceived(captor.capture())
val deviceMessage = captor.firstValue
assertThat(message).isEqualTo(deviceMessage.message)
assertThat(message.size).isEqualTo(deviceMessage.originalMessageSize)
}
@Test
fun onMessageReceived_invokedWithMultiplePacketMessage() {
val listener: MessageReceivedListener = mock()
stream.messageReceivedListener = listener
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE + 1)
val packets = createPackets(message)
packets.forEach { protocol.receiveData(it.toByteArray()) }
val captor = argumentCaptor<DeviceMessage>()
verify(listener).onMessageReceived(captor.capture())
val deviceMessage = captor.firstValue
assertThat(message).isEqualTo(deviceMessage.message)
assertThat(message.size).isEqualTo(deviceMessage.originalMessageSize)
}
@Test
fun onMessageReceived_parallelMessagesSuccessful() {
val listener: MessageReceivedListener = mock()
stream.messageReceivedListener = listener
val longMessage = ByteUtils.randomBytes((MAX_WRITE_SIZE * 2.5).toInt())
val shortMessage = ByteUtils.randomBytes(MAX_WRITE_SIZE / 2)
val longPackets = createPackets(longMessage)
val shortPackets = createPackets(shortMessage)
protocol.receiveData(longPackets.first().toByteArray())
protocol.receiveData(shortPackets.first().toByteArray())
val captor = argumentCaptor<DeviceMessage>()
verify(listener).onMessageReceived(captor.capture())
assertThat(shortMessage).isEqualTo(captor.firstValue.message)
}
@Test
fun onDataReceived_duplicatePacketIsIgnored() {
val listener: MessageReceivedListener = mock()
stream.messageReceivedListener = listener
val message = ByteUtils.randomBytes(MAX_WRITE_SIZE + 1)
val packets = createPackets(message)
protocol.receiveData(packets.first().toByteArray())
packets.forEach { protocol.receiveData(it.toByteArray()) }
val captor = argumentCaptor<DeviceMessage>()
verify(listener).onMessageReceived(captor.capture())
assertThat(message).isEqualTo(captor.firstValue.message)
}
@Test
fun onDataReceived_packetBeforeExpectedRangeDisconnectsDevice() {
val message = ByteUtils.randomBytes((MAX_WRITE_SIZE * 2.5).toInt())
val packets = createPackets(message)
protocol.receiveData(packets[0].toByteArray())
protocol.receiveData(packets[1].toByteArray())
protocol.receiveData(packets[0].toByteArray())
verify(protocol).disconnectDevice(PROTOCOL_ID)
}
@Test
fun onDataReceived_packetAfterExpectedRangeDisconnectsDevice() {
val message = ByteUtils.randomBytes((MAX_WRITE_SIZE * 2.5).toInt())
val packets = createPackets(message)
protocol.receiveData(packets.first().toByteArray())
protocol.receiveData(packets.last().toByteArray())
verify(protocol).disconnectDevice(PROTOCOL_ID)
}
private fun createPackets(data: ByteArray): List<Packet> {
return try {
val message =
Message.newBuilder()
.setPayload(ByteString.copyFrom(data))
.setOperation(OperationType.CLIENT_MESSAGE)
.setOriginalSize(data.size)
.build()
PacketFactory.makePackets(
message.toByteArray(),
ThreadLocalRandom.current().nextInt(),
MAX_WRITE_SIZE
)
} catch (e: Exception) {
Truth.assertWithMessage("Uncaught exception while making packets.").fail()
ArrayList()
}
}
open class TestProtocol : ConnectionProtocol(directExecutor()) {
override fun isDeviceVerificationRequired() = false
fun receiveData(data: ByteArray) {
dataReceivedListeners[PROTOCOL_ID]?.invoke { it.onDataReceived(PROTOCOL_ID, data) }
}
override fun startAssociationDiscovery(
name: String,
identifier: ParcelUuid,
callback: IDiscoveryCallback
) {}
override fun startConnectionDiscovery(
id: ParcelUuid,
challenge: ConnectChallenge,
callback: IDiscoveryCallback
) {}
override fun stopAssociationDiscovery() {}
override fun stopConnectionDiscovery(id: ParcelUuid) {}
override fun sendData(protocolId: String, data: ByteArray, callback: IDataSendCallback?) {
callback?.onDataSentSuccessfully()
}
override fun disconnectDevice(protocolId: String) {
deviceDisconnectedListeners[protocolId]?.invoke { it.onDeviceDisconnected(protocolId) }
}
override fun reset() {}
override fun getMaxWriteSize(protocolId: String): Int {
return MAX_WRITE_SIZE
}
}
open class FailingSendProtocol : ConnectionProtocol(directExecutor()) {
override fun isDeviceVerificationRequired() = false
override fun startAssociationDiscovery(
name: String,
identifier: ParcelUuid,
callback: IDiscoveryCallback
) {}
override fun startConnectionDiscovery(
id: ParcelUuid,
challenge: ConnectChallenge,
callback: IDiscoveryCallback
) {}
override fun stopAssociationDiscovery() {}
override fun stopConnectionDiscovery(id: ParcelUuid) {}
override fun sendData(protocolId: String, data: ByteArray, callback: IDataSendCallback?) {
callback?.onDataFailedToSend()
}
override fun disconnectDevice(protocolId: String) {}
override fun reset() {}
override fun getMaxWriteSize(protocolId: String): Int {
return MAX_WRITE_SIZE
}
}
}