[MS09] Implement isSameNetwork.

Test: Old tests pass, new tests pass too.
Bug: 113554482

Change-Id: I420471853f3fab7725cba7ae500cebdce1912e43
diff --git a/core/java/android/net/ipmemorystore/NetworkAttributes.java b/core/java/android/net/ipmemorystore/NetworkAttributes.java
index b932d21..5397b57 100644
--- a/core/java/android/net/ipmemorystore/NetworkAttributes.java
+++ b/core/java/android/net/ipmemorystore/NetworkAttributes.java
@@ -37,27 +37,57 @@
 public class NetworkAttributes {
     private static final boolean DBG = true;
 
+    // Weight cutoff for grouping. To group, a similarity score is computed with the following
+    // algorithm : if both fields are non-null and equals() then add their assigned weight, else if
+    // both are null then add a portion of their assigned weight (see NULL_MATCH_WEIGHT),
+    // otherwise add nothing.
+    // As a guideline, this should be something like 60~75% of the total weights in this class. The
+    // design states "in essence a reader should imagine that if two important columns don't match,
+    // or one important and several unimportant columns don't match then the two records are
+    // considered a different group".
+    private static final float TOTAL_WEIGHT_CUTOFF = 520.0f;
+    // The portion of the weight that is earned when scoring group-sameness by having both columns
+    // being null. This is because some networks rightfully don't have some attributes (e.g. a
+    // V6-only network won't have an assigned V4 address) and both being null should count for
+    // something, but attributes may also be null just because data is unavailable.
+    private static final float NULL_MATCH_WEIGHT = 0.25f;
+
     // The v4 address that was assigned to this device the last time it joined this network.
     // This typically comes from DHCP but could be something else like static configuration.
     // This does not apply to IPv6.
     // TODO : add a list of v6 prefixes for the v6 case.
     @Nullable
     public final Inet4Address assignedV4Address;
+    private static final float WEIGHT_ASSIGNEDV4ADDR = 300.0f;
 
     // Optionally supplied by the client if it has an opinion on L3 network. For example, this
     // could be a hash of the SSID + security type on WiFi.
     @Nullable
     public final String groupHint;
+    private static final float WEIGHT_GROUPHINT = 300.0f;
 
     // The list of DNS server addresses.
     @Nullable
     public final List<InetAddress> dnsAddresses;
+    private static final float WEIGHT_DNSADDRESSES = 200.0f;
 
     // The mtu on this network.
     @Nullable
     public final Integer mtu;
+    private static final float WEIGHT_MTU = 50.0f;
 
-    NetworkAttributes(
+    // The sum of all weights in this class. Tests ensure that this stays equal to the total of
+    // all weights.
+    /** @hide */
+    @VisibleForTesting
+    public static final float TOTAL_WEIGHT = WEIGHT_ASSIGNEDV4ADDR
+            + WEIGHT_GROUPHINT
+            + WEIGHT_DNSADDRESSES
+            + WEIGHT_MTU;
+
+    /** @hide */
+    @VisibleForTesting
+    public NetworkAttributes(
             @Nullable final Inet4Address assignedV4Address,
             @Nullable final String groupHint,
             @Nullable final List<InetAddress> dnsAddresses,
@@ -126,6 +156,34 @@
         return parcelable;
     }
 
+    private float samenessContribution(final float weight,
+            @Nullable final Object o1, @Nullable final Object o2) {
+        if (null == o1) {
+            return (null == o2) ? weight * NULL_MATCH_WEIGHT : 0f;
+        }
+        return Objects.equals(o1, o2) ? weight : 0f;
+    }
+
+    /** @hide */
+    public float getNetworkGroupSamenessConfidence(@NonNull final NetworkAttributes o) {
+        final float samenessScore =
+                samenessContribution(WEIGHT_ASSIGNEDV4ADDR, assignedV4Address, o.assignedV4Address)
+                + samenessContribution(WEIGHT_GROUPHINT, groupHint, o.groupHint)
+                + samenessContribution(WEIGHT_DNSADDRESSES, dnsAddresses, o.dnsAddresses)
+                + samenessContribution(WEIGHT_MTU, mtu, o.mtu);
+        // The minimum is 0, the max is TOTAL_WEIGHT and should be represented by 1.0, and
+        // TOTAL_WEIGHT_CUTOFF should represent 0.5, but there is no requirement that
+        // TOTAL_WEIGHT_CUTOFF would be half of TOTAL_WEIGHT (indeed, it should not be).
+        // So scale scores under the cutoff between 0 and 0.5, and the scores over the cutoff
+        // between 0.5 and 1.0.
+        if (samenessScore < TOTAL_WEIGHT_CUTOFF) {
+            return samenessScore / (TOTAL_WEIGHT_CUTOFF * 2);
+        } else {
+            return (samenessScore - TOTAL_WEIGHT_CUTOFF) / (TOTAL_WEIGHT - TOTAL_WEIGHT_CUTOFF) / 2
+                    + 0.5f;
+        }
+    }
+
     /** @hide */
     public static class Builder {
         @Nullable
diff --git a/core/java/android/net/ipmemorystore/SameL3NetworkResponse.java b/core/java/android/net/ipmemorystore/SameL3NetworkResponse.java
index d040dcc..291aca8 100644
--- a/core/java/android/net/ipmemorystore/SameL3NetworkResponse.java
+++ b/core/java/android/net/ipmemorystore/SameL3NetworkResponse.java
@@ -91,7 +91,8 @@
         return confidence > 0.5 ? NETWORK_SAME : NETWORK_DIFFERENT;
     }
 
-    SameL3NetworkResponse(@NonNull final String l2Key1, @NonNull final String l2Key2,
+    /** @hide */
+    public SameL3NetworkResponse(@NonNull final String l2Key1, @NonNull final String l2Key2,
             final float confidence) {
         this.l2Key1 = l2Key1;
         this.l2Key2 = l2Key2;
diff --git a/services/ipmemorystore/java/com/android/server/net/ipmemorystore/IpMemoryStoreService.java b/services/ipmemorystore/java/com/android/server/net/ipmemorystore/IpMemoryStoreService.java
index 444b299..c4d1657 100644
--- a/services/ipmemorystore/java/com/android/server/net/ipmemorystore/IpMemoryStoreService.java
+++ b/services/ipmemorystore/java/com/android/server/net/ipmemorystore/IpMemoryStoreService.java
@@ -37,6 +37,7 @@
 import android.net.ipmemorystore.IOnStatusListener;
 import android.net.ipmemorystore.NetworkAttributes;
 import android.net.ipmemorystore.NetworkAttributesParcelable;
+import android.net.ipmemorystore.SameL3NetworkResponse;
 import android.net.ipmemorystore.Status;
 import android.net.ipmemorystore.StatusParcelable;
 import android.net.ipmemorystore.Utils;
@@ -264,9 +265,40 @@
      * Through the listener, a SameL3NetworkResponse containing the answer and confidence.
      */
     @Override
-    public void isSameNetwork(@NonNull final String l2Key1, @NonNull final String l2Key2,
-            @NonNull final IOnSameNetworkResponseListener listener) {
-        // TODO : implement this.
+    public void isSameNetwork(@Nullable final String l2Key1, @Nullable final String l2Key2,
+            @Nullable final IOnSameNetworkResponseListener listener) {
+        if (null == listener) return;
+        mExecutor.execute(() -> {
+            try {
+                if (null == l2Key1 || null == l2Key2) {
+                    listener.onSameNetworkResponse(makeStatus(ERROR_ILLEGAL_ARGUMENT), null);
+                    return;
+                }
+                if (null == mDb) {
+                    listener.onSameNetworkResponse(makeStatus(ERROR_ILLEGAL_ARGUMENT), null);
+                    return;
+                }
+                try {
+                    final NetworkAttributes attr1 =
+                            IpMemoryStoreDatabase.retrieveNetworkAttributes(mDb, l2Key1);
+                    final NetworkAttributes attr2 =
+                            IpMemoryStoreDatabase.retrieveNetworkAttributes(mDb, l2Key2);
+                    if (null == attr1 || null == attr2) {
+                        listener.onSameNetworkResponse(makeStatus(SUCCESS),
+                                new SameL3NetworkResponse(l2Key1, l2Key2,
+                                        -1f /* never connected */).toParcelable());
+                        return;
+                    }
+                    final float confidence = attr1.getNetworkGroupSamenessConfidence(attr2);
+                    listener.onSameNetworkResponse(makeStatus(SUCCESS),
+                            new SameL3NetworkResponse(l2Key1, l2Key2, confidence).toParcelable());
+                } catch (Exception e) {
+                    listener.onSameNetworkResponse(makeStatus(ERROR_GENERIC), null);
+                }
+            } catch (final RemoteException e) {
+                // Client at the other end died
+            }
+        });
     }
 
     /**
diff --git a/tests/net/java/com/android/server/net/ipmemorystore/IpMemoryStoreServiceTest.java b/tests/net/java/com/android/server/net/ipmemorystore/IpMemoryStoreServiceTest.java
index 94bcd28..c58941a 100644
--- a/tests/net/java/com/android/server/net/ipmemorystore/IpMemoryStoreServiceTest.java
+++ b/tests/net/java/com/android/server/net/ipmemorystore/IpMemoryStoreServiceTest.java
@@ -28,9 +28,12 @@
 import android.net.ipmemorystore.Blob;
 import android.net.ipmemorystore.IOnBlobRetrievedListener;
 import android.net.ipmemorystore.IOnNetworkAttributesRetrieved;
+import android.net.ipmemorystore.IOnSameNetworkResponseListener;
 import android.net.ipmemorystore.IOnStatusListener;
 import android.net.ipmemorystore.NetworkAttributes;
 import android.net.ipmemorystore.NetworkAttributesParcelable;
+import android.net.ipmemorystore.SameL3NetworkResponse;
+import android.net.ipmemorystore.SameL3NetworkResponseParcelable;
 import android.net.ipmemorystore.Status;
 import android.net.ipmemorystore.StatusParcelable;
 import android.os.IBinder;
@@ -144,6 +147,28 @@
         };
     }
 
+    /** Helper method to make an IOnSameNetworkResponseListener */
+    private interface OnSameNetworkResponseListener {
+        void onSameNetworkResponse(Status status, SameL3NetworkResponse answer);
+    }
+    private IOnSameNetworkResponseListener onSameResponse(
+            final OnSameNetworkResponseListener functor) {
+        return new IOnSameNetworkResponseListener() {
+            @Override
+            public void onSameNetworkResponse(final StatusParcelable status,
+                    final SameL3NetworkResponseParcelable sameL3Network)
+                    throws RemoteException {
+                functor.onSameNetworkResponse(new Status(status),
+                        null == sameL3Network ? null : new SameL3NetworkResponse(sameL3Network));
+            }
+
+            @Override
+            public IBinder asBinder() {
+                return null;
+            }
+        };
+    }
+
     // Helper method to factorize some boilerplate
     private void doLatched(final String timeoutMessage, final Consumer<CountDownLatch> functor) {
         final CountDownLatch latch = new CountDownLatch(1);
@@ -155,6 +180,19 @@
         }
     }
 
+    // Helper methods to factorize more boilerplate
+    private void storeAttributes(final String l2Key, final NetworkAttributes na) {
+        storeAttributes("Did not complete storing attributes", l2Key, na);
+    }
+    private void storeAttributes(final String timeoutMessage, final String l2Key,
+            final NetworkAttributes na) {
+        doLatched(timeoutMessage, latch -> mService.storeNetworkAttributes(l2Key, na.toParcelable(),
+                onStatus(status -> {
+                    assertTrue("Store not successful : " + status.resultCode, status.isSuccess());
+                    latch.countDown();
+                })));
+    }
+
     @Test
     public void testNetworkAttributes() {
         final NetworkAttributes.Builder na = new NetworkAttributes.Builder();
@@ -166,13 +204,7 @@
         na.setMtu(219);
         final String l2Key = UUID.randomUUID().toString();
         NetworkAttributes attributes = na.build();
-        doLatched("Did not complete storing attributes", latch ->
-                mService.storeNetworkAttributes(l2Key, attributes.toParcelable(),
-                        onStatus(status -> {
-                            assertTrue("Store status not successful : " + status.resultCode,
-                                    status.isSuccess());
-                            latch.countDown();
-                        })));
+        storeAttributes(l2Key, attributes);
 
         doLatched("Did not complete retrieving attributes", latch ->
                 mService.retrieveNetworkAttributes(l2Key, onNetworkAttributesRetrieved(
@@ -190,9 +222,7 @@
                     new InetAddress[] {Inet6Address.getByName("0A1C:2E40:480A::1CA6")}));
         } catch (UnknownHostException e) { /* Still can't happen */ }
         final NetworkAttributes attributes2 = na2.build();
-        doLatched("Did not complete storing attributes 2", latch ->
-                mService.storeNetworkAttributes(l2Key, attributes2.toParcelable(),
-                        onStatus(status -> latch.countDown())));
+        storeAttributes("Did not complete storing attributes 2", l2Key, attributes2);
 
         doLatched("Did not complete retrieving attributes 2", latch ->
                 mService.retrieveNetworkAttributes(l2Key, onNetworkAttributesRetrieved(
@@ -306,8 +336,54 @@
         // TODO : implement this
     }
 
+    private void assertNetworksSameness(final String key1, final String key2, final int sameness) {
+        doLatched("Did not finish evaluating sameness", latch ->
+                mService.isSameNetwork(key1, key2, onSameResponse((status, answer) -> {
+                    assertTrue("Retrieve network sameness not successful : " + status.resultCode,
+                            status.isSuccess());
+                    assertEquals(sameness, answer.getNetworkSameness());
+                })));
+    }
+
     @Test
-    public void testIsSameNetwork() {
-        // TODO : implement this
+    public void testIsSameNetwork() throws UnknownHostException {
+        final NetworkAttributes.Builder na = new NetworkAttributes.Builder();
+        na.setAssignedV4Address((Inet4Address) Inet4Address.getByAddress(new byte[]{1, 2, 3, 4}));
+        na.setGroupHint("hint1");
+        na.setMtu(219);
+        na.setDnsAddresses(Arrays.asList(Inet6Address.getByName("0A1C:2E40:480A::1CA6")));
+
+        final String[] keys = new String[4];
+        for (int i = 0; i < keys.length; ++i) {
+            keys[i] = UUID.randomUUID().toString();
+        }
+        storeAttributes(keys[0], na.build());
+        // 0 and 1 have identical attributes
+        storeAttributes(keys[1], na.build());
+
+        // Hopefully only the MTU being different still means it's the same network
+        na.setMtu(200);
+        storeAttributes(keys[2], na.build());
+
+        // Hopefully different MTU, assigned V4 address and grouphint make a different network,
+        // even with identical DNS addresses
+        na.setAssignedV4Address(null);
+        na.setGroupHint("hint2");
+        storeAttributes(keys[3], na.build());
+
+        assertNetworksSameness(keys[0], keys[1], SameL3NetworkResponse.NETWORK_SAME);
+        assertNetworksSameness(keys[0], keys[2], SameL3NetworkResponse.NETWORK_SAME);
+        assertNetworksSameness(keys[1], keys[2], SameL3NetworkResponse.NETWORK_SAME);
+        assertNetworksSameness(keys[0], keys[3], SameL3NetworkResponse.NETWORK_DIFFERENT);
+        assertNetworksSameness(keys[0], UUID.randomUUID().toString(),
+                SameL3NetworkResponse.NETWORK_NEVER_CONNECTED);
+
+        doLatched("Did not finish evaluating sameness", latch ->
+                mService.isSameNetwork(null, null, onSameResponse((status, answer) -> {
+                    assertFalse("Retrieve network sameness suspiciously successful : "
+                            + status.resultCode, status.isSuccess());
+                    assertEquals(Status.ERROR_ILLEGAL_ARGUMENT, status.resultCode);
+                    assertNull(answer);
+                })));
     }
 }
diff --git a/tests/net/java/com/android/server/net/ipmemorystore/NetworkAttributesTest.java b/tests/net/java/com/android/server/net/ipmemorystore/NetworkAttributesTest.java
new file mode 100644
index 0000000..fe19eee
--- /dev/null
+++ b/tests/net/java/com/android/server/net/ipmemorystore/NetworkAttributesTest.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.net.ipmemorystore;
+
+import static org.junit.Assert.assertEquals;
+
+import android.net.ipmemorystore.NetworkAttributes;
+import android.support.test.filters.SmallTest;
+import android.support.test.runner.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import java.lang.reflect.Field;
+import java.net.Inet4Address;
+import java.net.UnknownHostException;
+import java.util.Arrays;
+
+/** Unit tests for {@link NetworkAttributes}. */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class NetworkAttributesTest {
+    private static final String WEIGHT_FIELD_NAME_PREFIX = "WEIGHT_";
+    private static final float EPSILON = 0.0001f;
+
+    // This is running two tests to make sure the total weight is the sum of all weights. To be
+    // sure this is not fireproof, but you'd kind of need to do it on purpose to pass.
+    @Test
+    public void testTotalWeight() throws IllegalAccessException, UnknownHostException {
+        // Make sure that TOTAL_WEIGHT is equal to the sum of the fields starting with WEIGHT_
+        float sum = 0f;
+        final Field[] fieldList = NetworkAttributes.class.getDeclaredFields();
+        for (final Field field : fieldList) {
+            if (!field.getName().startsWith(WEIGHT_FIELD_NAME_PREFIX)) continue;
+            field.setAccessible(true);
+            sum += (float) field.get(null);
+        }
+        assertEquals(sum, NetworkAttributes.TOTAL_WEIGHT, EPSILON);
+
+        // Use directly the constructor with all attributes, and make sure that when compared
+        // to itself the score is a clean 1.0f.
+        final NetworkAttributes na =
+                new NetworkAttributes(
+                        (Inet4Address) Inet4Address.getByAddress(new byte[] {1, 2, 3, 4}),
+                        "some hint",
+                        Arrays.asList(Inet4Address.getByAddress(new byte[] {5, 6, 7, 8}),
+                                Inet4Address.getByAddress(new byte[] {9, 0, 1, 2})),
+                        98);
+        assertEquals(1.0f, na.getNetworkGroupSamenessConfidence(na), EPSILON);
+    }
+}