Merge "Trigger onLinkPropertiesChange callback when link address lifetime update." into main
diff --git a/src/android/net/apf/ApfFilter.java b/src/android/net/apf/ApfFilter.java
index e4b6564..1e45791 100644
--- a/src/android/net/apf/ApfFilter.java
+++ b/src/android/net/apf/ApfFilter.java
@@ -1518,19 +1518,17 @@
         gen.addLoad16(R0, ARP_OPCODE_OFFSET);
         if (mIPv4Address == null) {
             // Drop if ARP REQUEST and we do not have an IPv4 address
-            maybeSetupCounter(gen, Counter.DROPPED_ARP_REQUEST_NO_ADDRESS);
-            gen.addJumpIfR0Equals(ARP_OPCODE_REQUEST, mCountAndDropLabel);
+            gen.addCountAndDropIfR0Equals(ARP_OPCODE_REQUEST,
+                    Counter.DROPPED_ARP_REQUEST_NO_ADDRESS);
         } else {
             gen.addJumpIfR0Equals(ARP_OPCODE_REQUEST, checkTargetIPv4); // Skip to unicast check
         }
         // Drop if unknown ARP opcode.
-        maybeSetupCounter(gen, Counter.DROPPED_ARP_UNKNOWN);
-        gen.addJumpIfR0NotEquals(ARP_OPCODE_REPLY, mCountAndDropLabel);
+        gen.addCountAndDropIfR0NotEquals(ARP_OPCODE_REPLY, Counter.DROPPED_ARP_UNKNOWN);
 
         // Drop if ARP reply source IP is 0.0.0.0
         gen.addLoad32(R0, ARP_SOURCE_IP_ADDRESS_OFFSET);
-        maybeSetupCounter(gen, Counter.DROPPED_ARP_REPLY_SPA_NO_HOST);
-        gen.addJumpIfR0Equals(IPV4_ANY_HOST_ADDRESS, mCountAndDropLabel);
+        gen.addCountAndDropIfR0Equals(IPV4_ANY_HOST_ADDRESS, Counter.DROPPED_ARP_REPLY_SPA_NO_HOST);
 
         // Pass if non-broadcast reply.
         gen.addLoadImmediate(R0, ETH_DEST_ADDR_OFFSET);
@@ -1542,8 +1540,7 @@
         if (mIPv4Address == null) {
             // When there is no IPv4 address, drop GARP replies (b/29404209).
             gen.addLoad32(R0, ARP_TARGET_IP_ADDRESS_OFFSET);
-            maybeSetupCounter(gen, Counter.DROPPED_GARP_REPLY);
-            gen.addJumpIfR0Equals(IPV4_ANY_HOST_ADDRESS, mCountAndDropLabel);
+            gen.addCountAndDropIfR0Equals(IPV4_ANY_HOST_ADDRESS, Counter.DROPPED_GARP_REPLY);
         } else {
             // When there is an IPv4 address, drop unicast/broadcast requests
             // and broadcast replies with a different target IPv4 address.
@@ -1606,17 +1603,15 @@
             // If IPv4 destination address is in multicast range, drop.
             gen.addLoad8(R0, IPV4_DEST_ADDR_OFFSET);
             gen.addAnd(0xf0);
-            maybeSetupCounter(gen, Counter.DROPPED_IPV4_MULTICAST);
-            gen.addJumpIfR0Equals(0xe0, mCountAndDropLabel);
+            gen.addCountAndDropIfR0Equals(0xe0, Counter.DROPPED_IPV4_MULTICAST);
 
             // If IPv4 broadcast packet, drop regardless of L2 (b/30231088).
-            maybeSetupCounter(gen, Counter.DROPPED_IPV4_BROADCAST_ADDR);
             gen.addLoad32(R0, IPV4_DEST_ADDR_OFFSET);
-            gen.addJumpIfR0Equals(IPV4_BROADCAST_ADDRESS, mCountAndDropLabel);
+            gen.addCountAndDropIfR0Equals(IPV4_BROADCAST_ADDRESS,
+                    Counter.DROPPED_IPV4_BROADCAST_ADDR);
             if (mIPv4Address != null && mIPv4PrefixLength < 31) {
-                maybeSetupCounter(gen, Counter.DROPPED_IPV4_BROADCAST_NET);
                 int broadcastAddr = ipv4BroadcastAddress(mIPv4Address, mIPv4PrefixLength);
-                gen.addJumpIfR0Equals(broadcastAddr, mCountAndDropLabel);
+                gen.addCountAndDropIfR0Equals(broadcastAddr, Counter.DROPPED_IPV4_BROADCAST_NET);
             }
         }
 
@@ -1708,8 +1703,7 @@
 
         // MLD packets set the router-alert hop-by-hop option.
         // TODO: be smarter about not blindly passing every packet with HBH options.
-        maybeSetupCounter(gen, Counter.PASSED_MLD);
-        gen.addJumpIfR0Equals(IPPROTO_HOPOPTS, mCountAndPassLabel);
+        gen.addCountAndPassIfR0Equals(IPPROTO_HOPOPTS, Counter.PASSED_MLD);
 
         // Drop multicast if the multicast filter is enabled.
         if (mMulticastFilter) {
@@ -1732,9 +1726,8 @@
 
             // Drop all other packets sent to ff00::/8 (multicast prefix).
             gen.defineLabel(dropAllIPv6MulticastsLabel);
-            maybeSetupCounter(gen, Counter.DROPPED_IPV6_NON_ICMP_MULTICAST);
             gen.addLoad8(R0, IPV6_DEST_ADDR_OFFSET);
-            gen.addJumpIfR0Equals(0xff, mCountAndDropLabel);
+            gen.addCountAndDropIfR0Equals(0xff, Counter.DROPPED_IPV6_NON_ICMP_MULTICAST);
             // If any keepalive filter matches, drop
             generateV6KeepaliveFilters(gen);
             // Not multicast. Pass.
@@ -1743,8 +1736,7 @@
         } else {
             generateV6KeepaliveFilters(gen);
             // If not ICMPv6, pass.
-            maybeSetupCounter(gen, Counter.PASSED_IPV6_NON_ICMP);
-            gen.addJumpIfR0NotEquals(IPPROTO_ICMPV6, mCountAndPassLabel);
+            gen.addCountAndPassIfR0NotEquals(IPPROTO_ICMPV6, Counter.PASSED_IPV6_NON_ICMP);
         }
 
         // If we got this far, the packet is ICMPv6.  Drop some specific types.
@@ -1753,8 +1745,8 @@
         String skipUnsolicitedMulticastNALabel = "skipUnsolicitedMulticastNA";
         gen.addLoad8(R0, ICMP6_TYPE_OFFSET);
         // Drop all router solicitations (b/32833400)
-        maybeSetupCounter(gen, Counter.DROPPED_IPV6_ROUTER_SOLICITATION);
-        gen.addJumpIfR0Equals(ICMPV6_ROUTER_SOLICITATION, mCountAndDropLabel);
+        gen.addCountAndDropIfR0Equals(ICMPV6_ROUTER_SOLICITATION,
+                Counter.DROPPED_IPV6_ROUTER_SOLICITATION);
         // If not neighbor announcements, skip filter.
         gen.addJumpIfR0NotEquals(ICMPV6_NEIGHBOR_ADVERTISEMENT, skipUnsolicitedMulticastNALabel);
         // Drop all multicast NA to ff02::/120.
@@ -1999,14 +1991,13 @@
 
         if (mDrop802_3Frames) {
             // drop 802.3 frames (ethtype < 0x0600)
-            maybeSetupCounter(gen, Counter.DROPPED_802_3_FRAME);
-            gen.addJumpIfR0LessThan(ETH_TYPE_MIN, mCountAndDropLabel);
+            gen.addCountAndDropIfR0LessThan(ETH_TYPE_MIN, Counter.DROPPED_802_3_FRAME);
         }
 
         // Handle ether-type black list
-        maybeSetupCounter(gen, Counter.DROPPED_ETHERTYPE_DENYLISTED);
         for (int p : mEthTypeBlackList) {
-            gen.addJumpIfR0Equals(p, mCountAndDropLabel);
+            // TODO: on APFv4 this reloads R1 for every entry, growing the program; optimize.
+            gen.addCountAndDropIfR0Equals(p, Counter.DROPPED_ETHERTYPE_DENYLISTED);
         }
 
         // Add ARP filters:
diff --git a/src/android/net/apf/ApfV4Generator.java b/src/android/net/apf/ApfV4Generator.java
index 0bd5c43..f4f3fd4 100644
--- a/src/android/net/apf/ApfV4Generator.java
+++ b/src/android/net/apf/ApfV4Generator.java
@@ -69,8 +69,7 @@
      */
     @Override
     public ApfV4Generator addCountAndPass(ApfCounterTracker.Counter counter) {
-        if (mVersion >= 4)  addLoadImmediate(R1, counter.offset());
-        return addJump(mCountAndPassLabel);
+        return maybeAddLoadR1CounterOffset(counter).addJump(mCountAndPassLabel);
     }
 
     /**
@@ -83,8 +82,43 @@
      */
     @Override
     public ApfV4Generator addCountAndDrop(ApfCounterTracker.Counter counter) {
-        if (mVersion >= 4) addLoadImmediate(R1, counter.offset());
-        return addJump(mCountAndDropLabel);
+        return maybeAddLoadR1CounterOffset(counter).addJump(mCountAndDropLabel);
+    }
+
+    @Override
+    public ApfV4Generator addCountAndDropIfR0Equals(long val, ApfCounterTracker.Counter cnt) {
+        return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0Equals(val, mCountAndDropLabel);
+    }
+
+    @Override
+    public ApfV4Generator addCountAndPassIfR0Equals(long val, ApfCounterTracker.Counter cnt) {
+        return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0Equals(val, mCountAndPassLabel);
+    }
+
+    @Override
+    public ApfV4Generator addCountAndDropIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) {
+        return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0NotEquals(val, mCountAndDropLabel);
+    }
+
+    @Override
+    public ApfV4Generator addCountAndPassIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) {
+        return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0NotEquals(val, mCountAndPassLabel);
+    }
+
+    @Override
+    public ApfV4Generator addCountAndDropIfR0LessThan(long val, ApfCounterTracker.Counter cnt) {
+        if (val <= 0) {
+            throw new IllegalArgumentException("val must > 0, current val: " + val);
+        }
+        return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0LessThan(val, mCountAndDropLabel);
+    }
+
+    @Override
+    public ApfV4Generator addCountAndPassIfR0LessThan(long val, ApfCounterTracker.Counter cnt) {
+        if (val <= 0) {
+            throw new IllegalArgumentException("val must > 0, current val: " + val);
+        }
+        return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0LessThan(val, mCountAndPassLabel);
     }
 
     /**
@@ -109,4 +143,9 @@
                 .addStoreData(R0, 0) // *(R1 + 0) = R0
                 .addJump(DROP_LABEL);
     }
+
+    private ApfV4Generator maybeAddLoadR1CounterOffset(ApfCounterTracker.Counter counter) {
+        if (mVersion >= 4) return addLoadImmediate(R1, counter.offset());
+        return self();
+    }
 }
diff --git a/src/android/net/apf/ApfV4GeneratorBase.java b/src/android/net/apf/ApfV4GeneratorBase.java
index 886581b..f4519a0 100644
--- a/src/android/net/apf/ApfV4GeneratorBase.java
+++ b/src/android/net/apf/ApfV4GeneratorBase.java
@@ -282,6 +282,22 @@
     }
 
     /**
+     * Add instructions to the end of the program to increment the counter and drop the packet if
+     * R0 equals {@code val}.
+     * WARNING: may modify R1.
+     */
+    public abstract Type addCountAndDropIfR0Equals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException;
+
+    /**
+     * Add instructions to the end of the program to increment the counter and pass the packet if
+     * R0 equals {@code val}.
+     * WARNING: may modify R1.
+     */
+    public abstract Type addCountAndPassIfR0Equals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException;
+
+    /**
      * Add an instruction to the end of the program to jump to {@code target} if register R0's
      * value does not equal {@code value}.
      */
@@ -290,6 +306,22 @@
     }
 
     /**
+     * Add instructions to the end of the program to increment the counter and drop the packet if
+     * R0 does not equal {@code val}.
+     * WARNING: may modify R1.
+     */
+    public abstract Type addCountAndDropIfR0NotEquals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException;
+
+    /**
+     * Add instructions to the end of the program to increment the counter and pass the packet if
+     * R0 does not equal {@code val}.
+     * WARNING: may modify R1.
+     */
+    public abstract Type addCountAndPassIfR0NotEquals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException;
+
+    /**
      * Add an instruction to the end of the program to jump to {@code target} if register R0's
      * value is greater than {@code value}.
      */
@@ -306,6 +338,22 @@
     }
 
     /**
+     * Add instructions to the end of the program to increment the counter and drop the packet if
+     * R0 is less than {@code val}.
+     * WARNING: may modify R1.
+     */
+    public abstract Type addCountAndDropIfR0LessThan(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException;
+
+    /**
+     * Add instructions to the end of the program to increment the counter and pass the packet if
+     * R0 is less than {@code val}.
+     * WARNING: may modify R1.
+     */
+    public abstract Type addCountAndPassIfR0LessThan(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException;
+
+    /**
      * Add an instruction to the end of the program to jump to {@code target} if register R0's
      * value has any bits set that are also set in {@code value}.
      */
diff --git a/src/android/net/apf/ApfV6Generator.java b/src/android/net/apf/ApfV6Generator.java
index 245d47d..d1349ac 100644
--- a/src/android/net/apf/ApfV6Generator.java
+++ b/src/android/net/apf/ApfV6Generator.java
@@ -62,6 +62,63 @@
         return addCountAndDrop(counter.value());
     }
 
+    @Override
+    public ApfV6Generator addCountAndDropIfR0Equals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException {
+        final String tgt = getUniqueLabel();
+        return addJumpIfR0NotEquals(val, tgt).addCountAndDrop(cnt).defineLabel(tgt);
+    }
+
+    @Override
+    public ApfV6Generator addCountAndPassIfR0Equals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException {
+        final String tgt = getUniqueLabel();
+        return addJumpIfR0NotEquals(val, tgt).addCountAndPass(cnt).defineLabel(tgt);
+    }
+
+    @Override
+    public ApfV6Generator addCountAndDropIfR0NotEquals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException {
+        final String tgt = getUniqueLabel();
+        return addJumpIfR0Equals(val, tgt).addCountAndDrop(cnt).defineLabel(tgt);
+    }
+
+    @Override
+    public ApfV6Generator addCountAndPassIfR0NotEquals(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException {
+        final String tgt = getUniqueLabel();
+        return addJumpIfR0Equals(val, tgt).addCountAndPass(cnt).defineLabel(tgt);
+    }
+
+    @Override
+    public ApfV6Generator addCountAndDropIfR0LessThan(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException {
+        if (val <= 0) {
+            throw new IllegalArgumentException("val must > 0, current val: " + val);
+        }
+        final String tgt = getUniqueLabel();
+        return addJumpIfR0GreaterThan(val - 1, tgt).addCountAndDrop(cnt).defineLabel(tgt);
+    }
+
+    @Override
+    public ApfV6Generator addCountAndPassIfR0LessThan(long val, ApfCounterTracker.Counter cnt)
+            throws IllegalInstructionException {
+        if (val <= 0) {
+            throw new IllegalArgumentException("val must > 0, current val: " + val);
+        }
+        final String tgt = getUniqueLabel();
+        return addJumpIfR0GreaterThan(val - 1, tgt).addCountAndPass(cnt).defineLabel(tgt);
+    }
+
+    private int mLabelCount = 0;
+
+    /**
+     * Return a fresh label of the form {@code LABEL_<n>}, unique among labels from this method.
+     */
+    private String getUniqueLabel() {
+        return "LABEL_" + mLabelCount++;
+    }
+
     /**
      * This method is noop in APFv6.
      */
diff --git a/src/android/net/apf/BaseApfGenerator.java b/src/android/net/apf/BaseApfGenerator.java
index 76b02bf..5fc89dd 100644
--- a/src/android/net/apf/BaseApfGenerator.java
+++ b/src/android/net/apf/BaseApfGenerator.java
@@ -609,18 +609,15 @@
             // If we already know the size the length field, just use it
             switch (mLenFieldOverride) {
                 case -1:
-                    break;
+                    return maxSize;
                 case 1:
-                    return 1;
                 case 2:
-                    return 2;
                 case 4:
-                    return 3;
+                    return mLenFieldOverride;
                 default:
                     throw new IllegalStateException(
                             "mLenFieldOverride has invalid value: " + mLenFieldOverride);
             }
-            return maxSize;
         }
 
         private int calculateTargetLabelOffset() throws IllegalInstructionException {
diff --git a/tests/unit/src/android/net/apf/ApfV5Test.kt b/tests/unit/src/android/net/apf/ApfV5Test.kt
index 6e3b93e..cd5af52 100644
--- a/tests/unit/src/android/net/apf/ApfV5Test.kt
+++ b/tests/unit/src/android/net/apf/ApfV5Test.kt
@@ -656,6 +656,101 @@
     }
 
     @Test
+    fun testCountAndPassDropCompareR0() {
+        doTestCountAndPassDropCompareR0(
+                { mutableMapOf() },
+                { ApfV4Generator(APF_VERSION_4) }
+        )
+        doTestCountAndPassDropCompareR0(
+                { mutableMapOf(Counter.TOTAL_PACKETS to 1) },
+                { ApfV6Generator().addData(byteArrayOf()) }
+        )
+    }
+
+    private fun doTestCountAndPassDropCompareR0(
+            getInitialMap: () -> MutableMap<Counter, Long>,
+            getGenerator: () -> ApfV4GeneratorBase<*>
+    ) {
+        var program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndDropIfR0Equals(123, Counter.DROPPED_ETH_BROADCAST)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        var dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion)
+        var counterMap = decodeCountersIntoMap(dataRegion)
+        var expectedMap = getInitialMap()
+        expectedMap[Counter.DROPPED_ETH_BROADCAST] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndPassIfR0Equals(123, Counter.PASSED_ARP)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[Counter.PASSED_ARP] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndDropIfR0NotEquals(124, Counter.DROPPED_ETH_BROADCAST)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[Counter.DROPPED_ETH_BROADCAST] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndPassIfR0NotEquals(124, Counter.PASSED_ARP)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[Counter.PASSED_ARP] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndDropIfR0LessThan(124, Counter.DROPPED_ETH_BROADCAST)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[Counter.DROPPED_ETH_BROADCAST] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndPassIfR0LessThan(124, Counter.PASSED_ARP)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[Counter.PASSED_ARP] = 1
+        assertEquals(expectedMap, counterMap)
+    }
+
+    @Test
     fun testV4CountAndPassDrop() {
         var program = ApfV4Generator(APF_VERSION_4)
                 .addCountAndDrop(Counter.DROPPED_ETH_BROADCAST)