Merge "Generate fallback speed label in AccessPoint.java"
diff --git a/packages/SettingsLib/src/com/android/settingslib/wifi/AccessPoint.java b/packages/SettingsLib/src/com/android/settingslib/wifi/AccessPoint.java
index 422690a..6b7e177 100644
--- a/packages/SettingsLib/src/com/android/settingslib/wifi/AccessPoint.java
+++ b/packages/SettingsLib/src/com/android/settingslib/wifi/AccessPoint.java
@@ -124,8 +124,14 @@
     private final ConcurrentHashMap<String, ScanResult> mScanResultCache =
             new ConcurrentHashMap<String, ScanResult>(32);
 
-    /** Map of BSSIDs to speed values for individual ScanResults. */
-    private final Map<String, Integer> mScanResultScores = new HashMap<>();
+    /**
+     * Map of BSSIDs to the latest {@link ScoredNetwork} seen for that BSSID.
+     *
+     * <p>This cache should not be evicted with scan results, as the values here are used to
+     * generate a fallback in the absence of scores for the visible APs.
+     */
+    // TODO(b/63073866): change this to have score eviction logic
+    private final Map<String, ScoredNetwork> mScoredNetworkCache = new HashMap<>();
 
     /** Maximum age of scan results to hold onto while actively scanning. **/
     private static final long MAX_SCAN_RESULT_AGE_MS = 15000;
@@ -138,6 +144,7 @@
     static final String KEY_SPEED = "key_speed";
     static final String KEY_PSKTYPE = "key_psktype";
     static final String KEY_SCANRESULTCACHE = "key_scanresultcache";
+    static final String KEY_SCOREDNETWORKCACHE = "key_scorednetworkcache";
     static final String KEY_CONFIG = "key_config";
     static final String KEY_FQDN = "key_fqdn";
     static final String KEY_PROVIDER_FRIENDLY_NAME = "key_provider_friendly_name";
@@ -188,7 +195,7 @@
 
     private Object mTag;
 
-    private int mSpeed = Speed.NONE;
+    @Speed private int mSpeed = Speed.NONE;
     private boolean mIsScoredNetworkMetered = false;
 
     // used to co-relate internal vs returned accesspoint.
@@ -238,6 +245,13 @@
                 mScanResultCache.put(result.BSSID, result);
             }
         }
+        if (savedState.containsKey(KEY_SCOREDNETWORKCACHE)) {
+            ArrayList<ScoredNetwork> scoredNetworkArrayList =
+                    savedState.getParcelableArrayList(KEY_SCOREDNETWORKCACHE);
+            for (ScoredNetwork score : scoredNetworkArrayList) {
+                mScoredNetworkCache.put(score.networkKey.wifiKey.bssid, score);
+            }
+        }
         if (savedState.containsKey(KEY_FQDN)) {
             mFqdn = savedState.getString(KEY_FQDN);
         }
@@ -308,8 +323,8 @@
         this.mNetworkInfo = that.mNetworkInfo;
         this.mScanResultCache.clear();
         this.mScanResultCache.putAll(that.mScanResultCache);
-        this.mScanResultScores.clear();
-        this.mScanResultScores.putAll(that.mScanResultScores);
+        this.mScoredNetworkCache.clear();
+        this.mScoredNetworkCache.putAll(that.mScoredNetworkCache);
         this.mId = that.mId;
         this.mSpeed = that.mSpeed;
         this.mIsScoredNetworkMetered = that.mIsScoredNetworkMetered;
@@ -347,7 +362,7 @@
         if (isSaved() && !other.isSaved()) return -1;
         if (!isSaved() && other.isSaved()) return 1;
 
-        // Faster speeds go before slower speeds
+        // Faster speeds go before slower speeds (compares the stored getSpeed() values)
         if (getSpeed() != other.getSpeed()) {
             return other.getSpeed() - getSpeed();
         }
@@ -425,7 +440,6 @@
      */
     boolean update(WifiNetworkScoreCache scoreCache, boolean scoringUiEnabled) {
         boolean scoreChanged = false;
-        mScanResultScores.clear();
         if (scoringUiEnabled) {
             scoreChanged = updateScores(scoreCache);
         }
@@ -435,37 +449,80 @@
     /**
      * Updates the AccessPoint rankingScore and speed, returning true if the data has changed.
      *
+     * <p>Precondition: {@link #mRssi} is up to date before invoking this method.
+     *
      * @param scoreCache The score cache to use to retrieve scores.
+     * @return true if the set speed has changed
      */
     private boolean updateScores(WifiNetworkScoreCache scoreCache) {
-        int oldSpeed = mSpeed;
-        mSpeed = Speed.NONE;
-
         for (ScanResult result : mScanResultCache.values()) {
             ScoredNetwork score = scoreCache.getScoredNetwork(result);
             if (score == null) {
                 continue;
             }
-
-            int speed = score.calculateBadge(result.level);
-            mScanResultScores.put(result.BSSID, speed);
-            mSpeed = Math.max(mSpeed, speed);
+            mScoredNetworkCache.put(result.BSSID, score);
         }
 
-        // set mSpeed to the connected ScanResult if the AccessPoint is the active network
+        return updateSpeed();
+    }
+
+    /**
+     * Updates the internal speed, returning true if the update resulted in a speed label change.
+     */
+    private boolean updateSpeed() {
+        int oldSpeed = mSpeed;
+        mSpeed = generateAverageSpeedForSsid();
+
+        // set speed to the connected ScanResult if the AccessPoint is the active network
         if (isActive() && mInfo != null) {
-            NetworkKey key = NetworkKey.createFromWifiInfo(mInfo);
-            ScoredNetwork score = scoreCache.getScoredNetwork(key);
+            ScoredNetwork score = mScoredNetworkCache.get(mInfo.getBSSID());
             if (score != null) {
-                mSpeed = score.calculateBadge(mInfo.getRssi());
+                if (Log.isLoggable(TAG, Log.DEBUG)) {
+                    Log.d(TAG, "Set score using specific access point curve for connected AP: "
+                            + getSsidStr());
+                }
+                // TODO(b/63073866): Map using getLevel rather than specific rssi value so score
+                // doesn't change without a visible wifi bar change.
+                int speed = score.calculateBadge(mInfo.getRssi());
+                if (speed != Speed.NONE) {
+                    mSpeed = speed;
+                }
             }
         }
 
-        if(WifiTracker.sVerboseLogging) {
+        boolean changed = oldSpeed != mSpeed;
+        if (WifiTracker.sVerboseLogging && changed) {
             Log.i(TAG, String.format("%s: Set speed to %d", ssid, mSpeed));
         }
+        return changed;
+    }
 
-        return oldSpeed != mSpeed;
+    /** Averages all cached badges other than NONE at {@link #mRssi}, rounded to a Speed. */
+    @Speed private int generateAverageSpeedForSsid() {
+        if (mScoredNetworkCache.isEmpty()) {
+            return Speed.NONE;
+        }
+
+        int count = 0;
+        int totalSpeed = 0;
+
+        if (Log.isLoggable(TAG, Log.DEBUG)) {
+            Log.d(TAG, String.format("Generating fallback speed for %s using cache: %s",
+                    getSsidStr(), mScoredNetworkCache));
+        }
+
+        for (ScoredNetwork score : mScoredNetworkCache.values()) {
+            int speed = score.calculateBadge(mRssi);
+            if (speed != Speed.NONE) {
+                count++;
+                totalSpeed += speed;
+            }
+        }
+        int speed = roundToClosestSpeedEnum(count == 0 ? Speed.NONE : totalSpeed / count);
+        if (WifiTracker.sVerboseLogging) {
+            Log.i(TAG, String.format("%s generated fallback speed is: %d", getSsidStr(), speed));
+        }
+        return speed;
     }
 
     /**
@@ -580,8 +637,6 @@
 
     /** Updates {@link #mSeen} based on the scan result cache. */
     private void updateSeen() {
-        // TODO(sghuman): Set to now if connected
-
         long seen = 0;
         for (ScanResult result : mScanResultCache.values()) {
             if (result.timestamp > seen) {
@@ -940,17 +995,23 @@
         }
         stringBuilder.append("=").append(result.frequency);
         stringBuilder.append(",").append(result.level);
-        if (hasSpeed(result)) {
+        int speed = getSpecificApSpeed(result);
+        if (speed != Speed.NONE) {
             stringBuilder.append(",")
-                    .append(getSpeedLabel(mScanResultScores.get(result.BSSID)));
+                    .append(getSpeedLabel(speed));
         }
         stringBuilder.append("}");
         return stringBuilder.toString();
     }
 
-    private boolean hasSpeed(ScanResult result) {
-        return mScanResultScores.containsKey(result.BSSID)
-                && mScanResultScores.get(result.BSSID) != Speed.NONE;
+    @Speed private int getSpecificApSpeed(ScanResult result) {
+        ScoredNetwork score = mScoredNetworkCache.get(result.BSSID);
+        if (score == null) {
+            return Speed.NONE;
+        }
+        // Badge at this BSSID's own result.level. The averaged fallback speed is computed at
+        // mRssi instead, so mRssi may be the more useful input here when debugging.
+        return score.calculateBadge(result.level);
     }
 
     /**
@@ -1065,6 +1126,8 @@
         evictOldScanResults();
         savedState.putParcelableArrayList(KEY_SCANRESULTCACHE,
                 new ArrayList<ScanResult>(mScanResultCache.values()));
+        savedState.putParcelableArrayList(KEY_SCOREDNETWORKCACHE,
+                new ArrayList<>(mScoredNetworkCache.values()));
         if (mNetworkInfo != null) {
             savedState.putParcelable(KEY_NETWORKINFO, mNetworkInfo);
         }
@@ -1103,8 +1166,12 @@
             updateRssi();
             int newLevel = getLevel();
 
-            if (newLevel > 0 && newLevel != oldLevel && mAccessPointListener != null) {
-                mAccessPointListener.onLevelChanged(this);
+            if (newLevel > 0 && newLevel != oldLevel) {
+                // Recompute the speed only when the visible signal level changes
+                updateSpeed();
+                if (mAccessPointListener != null) {
+                    mAccessPointListener.onLevelChanged(this);
+                }
             }
             // This flag only comes from scans, is not easily saved in config
             if (security == SECURITY_PSK) {
@@ -1189,7 +1256,23 @@
     }
 
+    /** Buckets a raw speed at the midpoints between Speed values; below SLOW is NONE. */
+    @Speed
+    private int roundToClosestSpeedEnum(int speed) {
+        if (speed < Speed.SLOW) {
+            return Speed.NONE;
+        } else if (speed < (Speed.SLOW + Speed.MODERATE) / 2) {
+            return Speed.SLOW;
+        } else if (speed < (Speed.MODERATE + Speed.FAST) / 2) {
+            return Speed.MODERATE;
+        } else if (speed < (Speed.FAST + Speed.VERY_FAST) / 2) {
+            return Speed.FAST;
+        } else {
+            return Speed.VERY_FAST;
+        }
+    }
+
     @Nullable
-    private String getSpeedLabel(int speed) {
+    private String getSpeedLabel(@Speed int speed) {
         switch (speed) {
             case Speed.VERY_FAST:
                 return mContext.getString(R.string.speed_label_very_fast);
diff --git a/packages/SettingsLib/tests/integ/src/com/android/settingslib/wifi/AccessPointTest.java b/packages/SettingsLib/tests/integ/src/com/android/settingslib/wifi/AccessPointTest.java
index ae59d37..083d0c5 100644
--- a/packages/SettingsLib/tests/integ/src/com/android/settingslib/wifi/AccessPointTest.java
+++ b/packages/SettingsLib/tests/integ/src/com/android/settingslib/wifi/AccessPointTest.java
@@ -66,10 +66,25 @@
 public class AccessPointTest {
 
     private static final String TEST_SSID = "test_ssid";
+    private static final int NUM_SCAN_RESULTS = 5;
+
+    private static final ArrayList<ScanResult> SCAN_RESULTS = buildScanResultCache();
+
     private Context mContext;
     @Mock private RssiCurve mockBadgeCurve;
     @Mock private WifiNetworkScoreCache mockWifiNetworkScoreCache;
 
+    /** Creates a ScanResult with the given ssid, bssid and rssi level, timestamped now. */
+    private static ScanResult createScanResult(String ssid, String bssid, int rssi) {
+        ScanResult scanResult = new ScanResult();
+        scanResult.SSID = ssid;
+        scanResult.level = rssi;
+        scanResult.BSSID = bssid;
+        scanResult.timestamp = SystemClock.elapsedRealtime() * 1000;
+        scanResult.capabilities = "";
+        return scanResult;
+    }
+
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
@@ -400,7 +414,7 @@
     }
 
     @Test
-    public void testSpeedLabel_isDerivedFromConnectedBssid() {
+    public void testSpeedLabel_isDerivedFromConnectedBssidWhenScoreAvailable() {
         int rssi = -55;
         String bssid = "00:00:00:00:00:00";
         int networkId = 123;
@@ -411,24 +425,36 @@
         info.setBSSID(bssid);
         info.setNetworkId(networkId);
 
+        ArrayList<ScanResult> scanResults = new ArrayList<>();
+        ScanResult scanResultUnconnected = createScanResult(TEST_SSID, "11:11:11:11:11:11", rssi);
+        scanResults.add(scanResultUnconnected);
+
+        ScanResult scanResultConnected = createScanResult(TEST_SSID, bssid, rssi);
+        scanResults.add(scanResultConnected);
+
         AccessPoint ap =
                 new TestAccessPointBuilder(mContext)
                         .setActive(true)
                         .setNetworkId(networkId)
                         .setSsid(TEST_SSID)
-                        .setScanResultCache(buildScanResultCache())
+                        .setScanResultCache(scanResults)
                         .setWifiInfo(info)
                         .build();
 
-        NetworkKey key = new NetworkKey(new WifiKey('"' + TEST_SSID + '"', bssid));
-        when(mockWifiNetworkScoreCache.getScoredNetwork(key))
+        when(mockWifiNetworkScoreCache.getScoredNetwork(scanResultUnconnected))
                 .thenReturn(buildScoredNetworkWithMockBadgeCurve());
-        when(mockBadgeCurve.lookupScore(anyInt())).thenReturn((byte) AccessPoint.Speed.FAST);
+        when(mockBadgeCurve.lookupScore(anyInt())).thenReturn((byte) Speed.SLOW);
+
+        // The connected BSSID has its own, faster score; it must win over the averaged fallback.
+        int connectedSpeed = Speed.VERY_FAST;
+        RssiCurve connectedBadgeCurve = mock(RssiCurve.class);
+        when(connectedBadgeCurve.lookupScore(anyInt())).thenReturn((byte) connectedSpeed);
+        when(mockWifiNetworkScoreCache.getScoredNetwork(scanResultConnected))
+                .thenReturn(buildScoredNetworkWithGivenBadgeCurve(connectedBadgeCurve));
 
         ap.update(mockWifiNetworkScoreCache, true /* scoringUiEnabled */);
 
-        verify(mockWifiNetworkScoreCache, times(2)).getScoredNetwork(key);
-        assertThat(ap.getSpeed()).isEqualTo(AccessPoint.Speed.FAST);
+        assertThat(ap.getSpeed()).isEqualTo(connectedSpeed);
     }
 
     @Test
@@ -562,8 +594,13 @@
     }
 
     private ScoredNetwork buildScoredNetworkWithMockBadgeCurve() {
+        return buildScoredNetworkWithGivenBadgeCurve(mockBadgeCurve);
+    }
+
+    /** Builds a ScoredNetwork that uses {@code badgeCurve} as both its rssi and badging curve. */
+    private ScoredNetwork buildScoredNetworkWithGivenBadgeCurve(RssiCurve badgeCurve) {
         Bundle attr1 = new Bundle();
-        attr1.putParcelable(ScoredNetwork.ATTRIBUTES_KEY_BADGING_CURVE, mockBadgeCurve);
+        attr1.putParcelable(ScoredNetwork.ATTRIBUTES_KEY_BADGING_CURVE, badgeCurve);
         return new ScoredNetwork(
                 new NetworkKey(new WifiKey("\"ssid\"", "00:00:00:00:00:00")),
-                mockBadgeCurve,
+                badgeCurve,
@@ -574,19 +611,14 @@
 
     private AccessPoint createAccessPointWithScanResultCache() {
         Bundle bundle = new Bundle();
-        ArrayList<ScanResult> scanResults = buildScanResultCache();
-        bundle.putParcelableArrayList(AccessPoint.KEY_SCANRESULTCACHE, scanResults);
+        bundle.putParcelableArrayList(AccessPoint.KEY_SCANRESULTCACHE, SCAN_RESULTS);
         return new AccessPoint(mContext, bundle);
     }
 
-    private ArrayList<ScanResult> buildScanResultCache() {
+    private static ArrayList<ScanResult> buildScanResultCache() {
         ArrayList<ScanResult> scanResults = new ArrayList<>();
-        for (int i = 0; i < 5; i++) {
-            ScanResult scanResult = new ScanResult();
-            scanResult.level = i;
-            scanResult.BSSID = "bssid-" + i;
-            scanResult.timestamp = SystemClock.elapsedRealtime() * 1000;
-            scanResult.capabilities = "";
+        for (int i = 0; i < NUM_SCAN_RESULTS; i++) {
+            ScanResult scanResult = createScanResult(TEST_SSID, "bssid-" + i, i);
             scanResults.add(scanResult);
         }
         return scanResults;
@@ -849,4 +881,90 @@
 
         ap.update(null, wifiInfo, networkInfo);
     }
+
+    /** The speed is the rounded average of the badges of every scored BSSID. */
+    @Test
+    public void testSpeedLabelAveragesAllBssidScores() {
+        AccessPoint ap = createAccessPointWithScanResultCache();
+
+        int speed1 = Speed.MODERATE;
+        RssiCurve badgeCurve1 = mock(RssiCurve.class);
+        when(badgeCurve1.lookupScore(anyInt())).thenReturn((byte) speed1);
+        when(mockWifiNetworkScoreCache.getScoredNetwork(SCAN_RESULTS.get(0)))
+                .thenReturn(buildScoredNetworkWithGivenBadgeCurve(badgeCurve1));
+        int speed2 = Speed.VERY_FAST;
+        RssiCurve badgeCurve2 = mock(RssiCurve.class);
+        when(badgeCurve2.lookupScore(anyInt())).thenReturn((byte) speed2);
+        when(mockWifiNetworkScoreCache.getScoredNetwork(SCAN_RESULTS.get(1)))
+                .thenReturn(buildScoredNetworkWithGivenBadgeCurve(badgeCurve2));
+
+        int expectedSpeed = (speed1 + speed2) / 2;
+
+        ap.update(mockWifiNetworkScoreCache, true /* scoringUiEnabled */);
+
+        assertThat(ap.getSpeed()).isEqualTo(expectedSpeed);
+    }
+
+    /** BSSIDs whose badge is Speed.NONE are left out of the average. */
+    @Test
+    public void testSpeedLabelAverageIgnoresNoSpeedScores() {
+        AccessPoint ap = createAccessPointWithScanResultCache();
+
+        int speed1 = Speed.VERY_FAST;
+        RssiCurve badgeCurve1 = mock(RssiCurve.class);
+        when(badgeCurve1.lookupScore(anyInt())).thenReturn((byte) speed1);
+        when(mockWifiNetworkScoreCache.getScoredNetwork(SCAN_RESULTS.get(0)))
+                .thenReturn(buildScoredNetworkWithGivenBadgeCurve(badgeCurve1));
+        int speed2 = Speed.NONE;
+        RssiCurve badgeCurve2 = mock(RssiCurve.class);
+        when(badgeCurve2.lookupScore(anyInt())).thenReturn((byte) speed2);
+        when(mockWifiNetworkScoreCache.getScoredNetwork(SCAN_RESULTS.get(1)))
+                .thenReturn(buildScoredNetworkWithGivenBadgeCurve(badgeCurve2));
+
+        ap.update(mockWifiNetworkScoreCache, true /* scoringUiEnabled */);
+
+        assertThat(ap.getSpeed()).isEqualTo(speed1);
+    }
+
+    /** When the connected BSSID has no score, the averaged fallback speed is used. */
+    @Test
+    public void testSpeedLabelUsesFallbackScoreWhenConnectedAccessPointScoreUnavailable() {
+        int rssi = -55;
+        String bssid = "00:00:00:00:00:00";
+        int networkId = 123;
+
+        WifiInfo info = new WifiInfo();
+        info.setRssi(rssi);
+        info.setSSID(WifiSsid.createFromAsciiEncoded(TEST_SSID));
+        info.setBSSID(bssid);
+        info.setNetworkId(networkId);
+
+        ArrayList<ScanResult> scanResults = new ArrayList<>();
+        ScanResult scanResultUnconnected = createScanResult(TEST_SSID, "11:11:11:11:11:11", rssi);
+        scanResults.add(scanResultUnconnected);
+
+        ScanResult scanResultConnected = createScanResult(TEST_SSID, bssid, rssi);
+        scanResults.add(scanResultConnected);
+
+        AccessPoint ap =
+                new TestAccessPointBuilder(mContext)
+                        .setActive(true)
+                        .setNetworkId(networkId)
+                        .setSsid(TEST_SSID)
+                        .setScanResultCache(scanResults)
+                        .setWifiInfo(info)
+                        .build();
+
+        int fallbackSpeed = Speed.SLOW;
+        when(mockWifiNetworkScoreCache.getScoredNetwork(scanResultUnconnected))
+                .thenReturn(buildScoredNetworkWithMockBadgeCurve());
+        when(mockBadgeCurve.lookupScore(anyInt())).thenReturn((byte) fallbackSpeed);
+
+        when(mockWifiNetworkScoreCache.getScoredNetwork(scanResultConnected))
+                .thenReturn(null);
+
+        ap.update(mockWifiNetworkScoreCache, true /* scoringUiEnabled */);
+
+        assertThat(ap.getSpeed()).isEqualTo(fallbackSpeed);
+    }
 }