Adding a per-uid cap on concurrent alarms

Reasonably, it is seen that most apps have a low number of alarms inside
alarm manager at any given time. In rare cases, app bugs can lead to
spam resulting in memory and computation pressure on the system.

(Also, removing some unused code)

Test: atest FrameworksMockingServicesTests:AlarmManagerServiceTest
atest CtsAlarmManagerTestCases

Bug: 31832077
Change-Id: Ie537e644619a787b7a730dc567aafa0258ae57f0
diff --git a/services/core/java/com/android/server/AlarmManagerService.java b/services/core/java/com/android/server/AlarmManagerService.java
index a400cc3..3d918fc 100644
--- a/services/core/java/com/android/server/AlarmManagerService.java
+++ b/services/core/java/com/android/server/AlarmManagerService.java
@@ -83,6 +83,7 @@
 import android.util.Slog;
 import android.util.SparseArray;
 import android.util.SparseBooleanArray;
+import android.util.SparseIntArray;
 import android.util.SparseLongArray;
 import android.util.StatsLog;
 import android.util.TimeUtils;
@@ -195,6 +196,7 @@
     private final Injector mInjector;
     int mBroadcastRefCount = 0;
     PowerManager.WakeLock mWakeLock;
+    SparseIntArray mAlarmsPerUid = new SparseIntArray();
     ArrayList<Alarm> mPendingNonWakeupAlarms = new ArrayList<>();
     ArrayList<InFlight> mInFlight = new ArrayList<>();
     private final ArrayList<AlarmManagerInternal.InFlightListener> mInFlightListeners =
@@ -393,6 +395,8 @@
         @VisibleForTesting
         static final String KEY_LISTENER_TIMEOUT = "listener_timeout";
         @VisibleForTesting
+        static final String KEY_MAX_ALARMS_PER_UID = "max_alarms_per_uid";
+        @VisibleForTesting
         static final String KEY_APP_STANDBY_QUOTAS_ENABLED = "app_standby_quotas_enabled";
         private static final String KEY_APP_STANDBY_WINDOW = "app_standby_window";
         @VisibleForTesting
@@ -420,6 +424,7 @@
         private static final long DEFAULT_ALLOW_WHILE_IDLE_LONG_TIME = 9*60*1000;
         private static final long DEFAULT_ALLOW_WHILE_IDLE_WHITELIST_DURATION = 10*1000;
         private static final long DEFAULT_LISTENER_TIMEOUT = 5 * 1000;
+        private static final int DEFAULT_MAX_ALARMS_PER_UID = 500;
         private static final boolean DEFAULT_APP_STANDBY_QUOTAS_ENABLED = true;
         private static final long DEFAULT_APP_STANDBY_WINDOW = 60 * 60 * 1000;  // 1 hr
         /**
@@ -461,6 +466,8 @@
 
         // Direct alarm listener callback timeout
         public long LISTENER_TIMEOUT = DEFAULT_LISTENER_TIMEOUT;
+        public int MAX_ALARMS_PER_UID = DEFAULT_MAX_ALARMS_PER_UID;
+
         public boolean APP_STANDBY_QUOTAS_ENABLED = DEFAULT_APP_STANDBY_QUOTAS_ENABLED;
 
         public long APP_STANDBY_WINDOW = DEFAULT_APP_STANDBY_WINDOW;
@@ -549,6 +556,15 @@
                     APP_STANDBY_QUOTAS[i] = mParser.getInt(KEYS_APP_STANDBY_QUOTAS[i],
                             Math.min(APP_STANDBY_QUOTAS[i - 1], DEFAULT_APP_STANDBY_QUOTAS[i]));
                 }
+
+                MAX_ALARMS_PER_UID = mParser.getInt(KEY_MAX_ALARMS_PER_UID,
+                        DEFAULT_MAX_ALARMS_PER_UID);
+                if (MAX_ALARMS_PER_UID < DEFAULT_MAX_ALARMS_PER_UID) {
+                    Slog.w(TAG, "Cannot set " + KEY_MAX_ALARMS_PER_UID + " lower than "
+                            + DEFAULT_MAX_ALARMS_PER_UID);
+                    MAX_ALARMS_PER_UID = DEFAULT_MAX_ALARMS_PER_UID;
+                }
+
                 updateAllowWhileIdleWhitelistDurationLocked();
             }
         }
@@ -590,6 +606,9 @@
             TimeUtils.formatDuration(ALLOW_WHILE_IDLE_WHITELIST_DURATION, pw);
             pw.println();
 
+            pw.print(KEY_MAX_ALARMS_PER_UID); pw.print("=");
+            pw.println(MAX_ALARMS_PER_UID);
+
             for (int i = 0; i < KEYS_APP_STANDBY_DELAY.length; i++) {
                 pw.print(KEYS_APP_STANDBY_DELAY[i]); pw.print("=");
                 TimeUtils.formatDuration(APP_STANDBY_MIN_DELAYS[i], pw);
@@ -671,12 +690,6 @@
 
         final ArrayList<Alarm> alarms = new ArrayList<Alarm>();
 
-        Batch() {
-            start = 0;
-            end = Long.MAX_VALUE;
-            flags = 0;
-        }
-
         Batch(Alarm seed) {
             start = seed.whenElapsed;
             end = clampPositive(seed.maxWhenElapsed);
@@ -728,11 +741,16 @@
             return newStart;
         }
 
+        /**
+         * Remove an alarm from this batch.
+         * <p> <b> Should be used only while re-ordering the alarm within the service </b> as it
+         * does not update {@link #mAlarmsPerUid}
+         */
         boolean remove(Alarm alarm) {
-            return remove(a -> (a == alarm));
+            return remove(a -> (a == alarm), true);
         }
 
-        boolean remove(Predicate<Alarm> predicate) {
+        boolean remove(Predicate<Alarm> predicate, boolean reOrdering) {
             boolean didRemove = false;
             long newStart = 0;  // recalculate endpoints as we go
             long newEnd = Long.MAX_VALUE;
@@ -741,6 +759,9 @@
                 Alarm alarm = alarms.get(i);
                 if (predicate.test(alarm)) {
                     alarms.remove(i);
+                    if (!reOrdering) {
+                        decrementAlarmCount(alarm.uid);
+                    }
                     didRemove = true;
                     if (alarm.alarmClock != null) {
                         mNextAlarmClockMayChange = true;
@@ -1734,6 +1755,15 @@
                         + " tElapsed=" + triggerElapsed + " maxElapsed=" + maxElapsed
                         + " interval=" + interval + " flags=0x" + Integer.toHexString(flags));
             }
+            if (mAlarmsPerUid.get(callingUid, 0) >= mConstants.MAX_ALARMS_PER_UID) {
+                final String errorMsg =
+                        "Maximum limit of concurrent alarms " + mConstants.MAX_ALARMS_PER_UID
+                                + " reached for uid: " + UserHandle.formatUid(callingUid)
+                                + ", callingPackage: " + callingPackage;
+                // STOPSHIP (b/128866264): Just to catch breakages. Remove before final release.
+                Slog.wtf(TAG, errorMsg);
+                throw new UnsupportedOperationException(errorMsg);
+            }
             setImplLocked(type, triggerAtTime, triggerElapsed, windowLength, maxElapsed,
                     interval, operation, directReceiver, listenerTag, flags, true, workSource,
                     alarmClock, callingUid, callingPackage);
@@ -1756,6 +1786,7 @@
         } catch (RemoteException e) {
         }
         removeLocked(operation, directReceiver);
+        incrementAlarmCount(a.uid);
         setImplLocked(a, false, doValidate);
     }
 
@@ -2197,7 +2228,6 @@
                             ? sdf.format(new Date(nowRTC - (nowELAPSED - time)))
                             : "-");
                 } while (i != mNextTickHistory);
-                pw.println();
             }
 
             SystemServiceManager ssm = LocalServices.getService(SystemServiceManager.class);
@@ -2305,6 +2335,18 @@
             if (!blocked) {
                 pw.println("    none");
             }
+            pw.println();
+            pw.print("  Pending alarms per uid: [");
+            for (int i = 0; i < mAlarmsPerUid.size(); i++) {
+                if (i > 0) {
+                    pw.print(", ");
+                }
+                UserHandle.formatUid(pw, mAlarmsPerUid.keyAt(i));
+                pw.print(":");
+                pw.print(mAlarmsPerUid.valueAt(i));
+            }
+            pw.println("]");
+            pw.println();
 
             mAppWakeupHistory.dump(pw, "  ", nowELAPSED);
 
@@ -3046,7 +3088,7 @@
         final Predicate<Alarm> whichAlarms = (Alarm a) -> a.matches(operation, directReceiver);
         for (int i = mAlarmBatches.size() - 1; i >= 0; i--) {
             Batch b = mAlarmBatches.get(i);
-            didRemove |= b.remove(whichAlarms);
+            didRemove |= b.remove(whichAlarms, false);
             if (b.size() == 0) {
                 mAlarmBatches.remove(i);
             }
@@ -3098,7 +3140,7 @@
         final Predicate<Alarm> whichAlarms = (Alarm a) -> a.uid == uid;
         for (int i = mAlarmBatches.size() - 1; i >= 0; i--) {
             Batch b = mAlarmBatches.get(i);
-            didRemove |= b.remove(whichAlarms);
+            didRemove |= b.remove(whichAlarms, false);
             if (b.size() == 0) {
                 mAlarmBatches.remove(i);
             }
@@ -3145,7 +3187,7 @@
         final boolean oldHasTick = haveBatchesTimeTickAlarm(mAlarmBatches);
         for (int i = mAlarmBatches.size() - 1; i >= 0; i--) {
             Batch b = mAlarmBatches.get(i);
-            didRemove |= b.remove(whichAlarms);
+            didRemove |= b.remove(whichAlarms, false);
             if (b.size() == 0) {
                 mAlarmBatches.remove(i);
             }
@@ -3183,6 +3225,7 @@
         }
     }
 
+    // Only called for ephemeral apps
     void removeForStoppedLocked(final int uid) {
         if (uid == Process.SYSTEM_UID) {
             Slog.wtf(TAG, "removeForStoppedLocked: Shouldn't for UID=" + uid);
@@ -3200,7 +3243,7 @@
         };
         for (int i = mAlarmBatches.size() - 1; i >= 0; i--) {
             Batch b = mAlarmBatches.get(i);
-            didRemove |= b.remove(whichAlarms);
+            didRemove |= b.remove(whichAlarms, false);
             if (b.size() == 0) {
                 mAlarmBatches.remove(i);
             }
@@ -3237,7 +3280,7 @@
                 (Alarm a) -> UserHandle.getUserId(a.creatorUid) == userHandle;
         for (int i = mAlarmBatches.size() - 1; i >= 0; i--) {
             Batch b = mAlarmBatches.get(i);
-            didRemove |= b.remove(whichAlarms);
+            didRemove |= b.remove(whichAlarms, false);
             if (b.size() == 0) {
                 mAlarmBatches.remove(i);
             }
@@ -3407,8 +3450,7 @@
         return mConstants.ALLOW_WHILE_IDLE_LONG_TIME;
     }
 
-    boolean triggerAlarmsLocked(ArrayList<Alarm> triggerList, final long nowELAPSED,
-            final long nowRTC) {
+    boolean triggerAlarmsLocked(ArrayList<Alarm> triggerList, final long nowELAPSED) {
         boolean hasWakeup = false;
         // batches are temporally sorted, so we need only pull from the
         // start of the list until we either empty it or hit a batch
@@ -3984,7 +4026,7 @@
                         }
 
                         mLastTrigger = nowELAPSED;
-                        boolean hasWakeup = triggerAlarmsLocked(triggerList, nowELAPSED, nowRTC);
+                        boolean hasWakeup = triggerAlarmsLocked(triggerList, nowELAPSED);
                         if (!hasWakeup && checkAllowNonWakeupDelayLocked(nowELAPSED)) {
                             // if there are no wakeup alarms and the screen is off, we can
                             // delay what we have so far until the future.
@@ -4108,9 +4150,8 @@
                 case ALARM_EVENT: {
                     ArrayList<Alarm> triggerList = new ArrayList<Alarm>();
                     synchronized (mLock) {
-                        final long nowRTC = mInjector.getCurrentTimeMillis();
                         final long nowELAPSED = mInjector.getElapsedRealtime();
-                        triggerAlarmsLocked(triggerList, nowELAPSED, nowRTC);
+                        triggerAlarmsLocked(triggerList, nowELAPSED);
                         updateNextAlarmClockLocked();
                     }
 
@@ -4719,7 +4760,7 @@
                 mAppWakeupHistory.recordAlarmForPackage(alarm.sourcePackage,
                         UserHandle.getUserId(alarm.creatorUid), nowELAPSED);
             }
-
+            decrementAlarmCount(alarm.uid);
             final BroadcastStats bs = inflight.mBroadcastStats;
             bs.count++;
             if (bs.nesting == 0) {
@@ -4747,6 +4788,27 @@
         }
     }
 
+    private void incrementAlarmCount(int uid) {
+        final int uidIndex = mAlarmsPerUid.indexOfKey(uid);
+        if (uidIndex >= 0) {
+            mAlarmsPerUid.setValueAt(uidIndex, mAlarmsPerUid.valueAt(uidIndex) + 1);
+        } else {
+            mAlarmsPerUid.put(uid, 1);
+        }
+    }
+
+    private void decrementAlarmCount(int uid) {
+        final int uidIndex = mAlarmsPerUid.indexOfKey(uid);
+        if (uidIndex >= 0) {
+            final int newCount = mAlarmsPerUid.valueAt(uidIndex) - 1;
+            if (newCount > 0) {
+                mAlarmsPerUid.setValueAt(uidIndex, newCount);
+            } else {
+                mAlarmsPerUid.removeAt(uidIndex);
+            }
+        }
+    }
+
     private class ShellCmd extends ShellCommand {
 
         IAlarmManager getBinderService() {
diff --git a/services/tests/mockingservicestests/src/com/android/server/AlarmManagerServiceTest.java b/services/tests/mockingservicestests/src/com/android/server/AlarmManagerServiceTest.java
index 6386b3b3..7c91b64 100644
--- a/services/tests/mockingservicestests/src/com/android/server/AlarmManagerServiceTest.java
+++ b/services/tests/mockingservicestests/src/com/android/server/AlarmManagerServiceTest.java
@@ -15,7 +15,10 @@
  */
 package com.android.server;
 
+import static android.app.AlarmManager.ELAPSED_REALTIME;
 import static android.app.AlarmManager.ELAPSED_REALTIME_WAKEUP;
+import static android.app.AlarmManager.RTC;
+import static android.app.AlarmManager.RTC_WAKEUP;
 import static android.app.usage.UsageStatsManager.STANDBY_BUCKET_ACTIVE;
 import static android.app.usage.UsageStatsManager.STANDBY_BUCKET_FREQUENT;
 import static android.app.usage.UsageStatsManager.STANDBY_BUCKET_RARE;
@@ -254,14 +257,22 @@
     }
 
     private void setTestAlarm(int type, long triggerTime, PendingIntent operation) {
+        setTestAlarm(type, triggerTime, operation, TEST_CALLING_UID);
+    }
+
+    private void setTestAlarm(int type, long triggerTime, PendingIntent operation, int callingUid) {
         mService.setImpl(type, triggerTime, AlarmManager.WINDOW_EXACT, 0,
                 operation, null, "test", AlarmManager.FLAG_STANDALONE, null, null,
-                TEST_CALLING_UID, TEST_CALLING_PACKAGE);
+                callingUid, TEST_CALLING_PACKAGE);
     }
 
     private PendingIntent getNewMockPendingIntent() {
+        return getNewMockPendingIntent(TEST_CALLING_UID);
+    }
+
+    private PendingIntent getNewMockPendingIntent(int mockUid) {
         final PendingIntent mockPi = mock(PendingIntent.class, Answers.RETURNS_DEEP_STUBS);
-        when(mockPi.getCreatorUid()).thenReturn(TEST_CALLING_UID);
+        when(mockPi.getCreatorUid()).thenReturn(mockUid);
         when(mockPi.getCreatorPackage()).thenReturn(TEST_CALLING_PACKAGE);
         return mockPi;
     }
@@ -724,6 +735,91 @@
         verify(mMockContext).sendBroadcastAsUser(mService.mTimeTickIntent, UserHandle.ALL);
     }
 
+    @Test
+    public void alarmCountKeyedOnCallingUid() {
+        final int mockCreatorUid = 431412;
+        final PendingIntent pi = getNewMockPendingIntent(mockCreatorUid);
+        setTestAlarm(ELAPSED_REALTIME, mNowElapsedTest + 5, pi);
+        assertEquals(1, mService.mAlarmsPerUid.get(TEST_CALLING_UID));
+        assertEquals(-1, mService.mAlarmsPerUid.get(mockCreatorUid, -1));
+    }
+
+    @Test
+    public void alarmCountOnSet() {
+        final int numAlarms = 103;
+        final int[] types = {RTC_WAKEUP, RTC, ELAPSED_REALTIME_WAKEUP, ELAPSED_REALTIME};
+        for (int i = 1; i <= numAlarms; i++) {
+            setTestAlarm(types[i % 4], mNowElapsedTest + i, getNewMockPendingIntent());
+            assertEquals(i, mService.mAlarmsPerUid.get(TEST_CALLING_UID));
+        }
+    }
+
+    @Test
+    public void alarmCountOnExpiration() throws InterruptedException {
+        final int numAlarms = 8; // This test is slow
+        for (int i = 0; i < numAlarms; i++) {
+            setTestAlarm(ELAPSED_REALTIME, mNowElapsedTest + i + 10, getNewMockPendingIntent());
+        }
+        int expired = 0;
+        while (expired < numAlarms) {
+            mNowElapsedTest = mTestTimer.getElapsed();
+            mTestTimer.expire();
+            expired++;
+            assertEquals(numAlarms - expired, mService.mAlarmsPerUid.get(TEST_CALLING_UID, 0));
+        }
+    }
+
+    @Test
+    public void alarmCountOnUidRemoved() {
+        final int numAlarms = 10;
+        for (int i = 0; i < numAlarms; i++) {
+            setTestAlarm(ELAPSED_REALTIME, mNowElapsedTest + i + 10, getNewMockPendingIntent());
+        }
+        assertEquals(numAlarms, mService.mAlarmsPerUid.get(TEST_CALLING_UID));
+        mService.removeLocked(TEST_CALLING_UID);
+        assertEquals(0, mService.mAlarmsPerUid.get(TEST_CALLING_UID, 0));
+    }
+
+    @Test
+    public void alarmCountOnPackageRemoved() {
+        final int numAlarms = 10;
+        for (int i = 0; i < numAlarms; i++) {
+            setTestAlarm(ELAPSED_REALTIME, mNowElapsedTest + i + 10, getNewMockPendingIntent());
+        }
+        assertEquals(numAlarms, mService.mAlarmsPerUid.get(TEST_CALLING_UID));
+        mService.removeLocked(TEST_CALLING_PACKAGE);
+        assertEquals(0, mService.mAlarmsPerUid.get(TEST_CALLING_UID, 0));
+    }
+
+    @Test
+    public void alarmCountOnUserRemoved() {
+        final int mockUserId = 15;
+        final int numAlarms = 10;
+        for (int i = 0; i < numAlarms; i++) {
+            int mockUid = UserHandle.getUid(mockUserId, 1234 + i);
+            setTestAlarm(ELAPSED_REALTIME, mNowElapsedTest + i + 10,
+                    getNewMockPendingIntent(mockUid), mockUid);
+        }
+        assertEquals(numAlarms, mService.mAlarmsPerUid.size());
+        mService.removeUserLocked(mockUserId);
+        assertEquals(0, mService.mAlarmsPerUid.size());
+    }
+
+    @Test
+    public void alarmCountOnAlarmRemoved() {
+        final int numAlarms = 10;
+        final PendingIntent[] pis = new PendingIntent[numAlarms];
+        for (int i = 0; i < numAlarms; i++) {
+            pis[i] = getNewMockPendingIntent();
+            setTestAlarm(ELAPSED_REALTIME, mNowElapsedTest + i + 5, pis[i]);
+        }
+        assertEquals(numAlarms, mService.mAlarmsPerUid.get(TEST_CALLING_UID));
+        for (int i = 0; i < numAlarms; i++) {
+            mService.removeLocked(pis[i], null);
+            assertEquals(numAlarms - i - 1, mService.mAlarmsPerUid.get(TEST_CALLING_UID, 0));
+        }
+    }
+
     @After
     public void tearDown() {
         if (mMockingSession != null) {