Merge "Prevent abuse of MediaRoute2ProviderService#notifyRequestFailed()" into rvc-dev
diff --git a/media/java/android/media/MediaRoute2ProviderService.java b/media/java/android/media/MediaRoute2ProviderService.java
index 05c6e3a..908fd82 100644
--- a/media/java/android/media/MediaRoute2ProviderService.java
+++ b/media/java/android/media/MediaRoute2ProviderService.java
@@ -40,8 +40,10 @@
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Deque;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -132,15 +134,21 @@
     @Retention(RetentionPolicy.SOURCE)
     public @interface Reason {}
 
+    private static final int MAX_REQUEST_IDS_SIZE = 500;
+
     private final Handler mHandler;
     private final Object mSessionLock = new Object();
+    private final Object mRequestIdsLock = new Object();
     private final AtomicBoolean mStatePublishScheduled = new AtomicBoolean(false);
     private MediaRoute2ProviderServiceStub mStub;
     private IMediaRoute2ProviderServiceCallback mRemoteCallback;
     private volatile MediaRoute2ProviderInfo mProviderInfo;
 
+    @GuardedBy("mRequestIdsLock")
+    private final Deque<Long> mRequestIds = new ArrayDeque<>(MAX_REQUEST_IDS_SIZE);
+
     @GuardedBy("mSessionLock")
-    private ArrayMap<String, RoutingSessionInfo> mSessionInfo = new ArrayMap<>();
+    private final ArrayMap<String, RoutingSessionInfo> mSessionInfo = new ArrayMap<>();
 
     public MediaRoute2ProviderService() {
         mHandler = new Handler(Looper.getMainLooper());
@@ -230,6 +238,11 @@
             @NonNull RoutingSessionInfo sessionInfo) {
         Objects.requireNonNull(sessionInfo, "sessionInfo must not be null");
 
+        if (requestId != REQUEST_ID_NONE && !removeRequestId(requestId)) {
+            Log.w(TAG, "notifySessionCreated: The requestId doesn't exist. requestId=" + requestId);
+            return;
+        }
+
         String sessionId = sessionInfo.getId();
         synchronized (mSessionLock) {
             if (mSessionInfo.containsKey(sessionId)) {
@@ -322,6 +335,13 @@
         if (mRemoteCallback == null) {
             return;
         }
+
+        if (!removeRequestId(requestId)) {
+            Log.w(TAG, "notifyRequestFailed: The requestId doesn't exist. requestId="
+                    + requestId);
+            return;
+        }
+
         try {
             mRemoteCallback.notifyRequestFailed(requestId, reason);
         } catch (RemoteException ex) {
@@ -469,6 +489,36 @@
         }
     }
 
+    /**
+     * Appends {@code requestId} to the received request ID list, which holds at most
+     * {@link #MAX_REQUEST_IDS_SIZE} IDs. If the list is already full, the oldest ID is dropped (FIFO).
+     */
+    private void addRequestId(long requestId) {
+        synchronized (mRequestIdsLock) {
+            if (mRequestIds.size() >= MAX_REQUEST_IDS_SIZE) {
+                mRequestIds.removeFirst(); // Evict the oldest ID to stay within the cap.
+            }
+            mRequestIds.addLast(requestId);
+        }
+    }
+
+    /**
+     * Removes the first occurrence of {@code requestId} from the received request ID list.
+     *
+     * @return {@code true} if the ID was in the list; {@code false} when one of these holds:
+     * <ul>
+     *     <li>This service has never received a request with the requestId. </li>
+     *     <li>{@link #notifyRequestFailed} or {@link #notifySessionCreated} already has been called
+     *         for the requestId. </li>
+     *     <li>The requestId was evicted by {@link #addRequestId} when the list was full. </li>
+     * </ul>
+     */
+    private boolean removeRequestId(long requestId) {
+        synchronized (mRequestIdsLock) {
+            return mRequestIds.removeFirstOccurrence(requestId);
+        }
+    }
+
     final class MediaRoute2ProviderServiceStub extends IMediaRoute2ProviderService.Stub {
         MediaRoute2ProviderServiceStub() { }
 
@@ -529,6 +579,7 @@
             if (!checkRouteIdIsValid(routeId, "setRouteVolume")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSetRouteVolume,
                     MediaRoute2ProviderService.this, requestId, routeId, volume));
         }
@@ -542,6 +593,7 @@
             if (!checkRouteIdIsValid(routeId, "requestCreateSession")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onCreateSession,
                     MediaRoute2ProviderService.this, requestId, packageName, routeId,
                     requestCreateSession));
@@ -556,6 +608,7 @@
                     || !checkRouteIdIsValid(routeId, "selectRoute")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSelectRoute,
                     MediaRoute2ProviderService.this, requestId, sessionId, routeId));
         }
@@ -569,6 +622,7 @@
                     || !checkRouteIdIsValid(routeId, "deselectRoute")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onDeselectRoute,
                     MediaRoute2ProviderService.this, requestId, sessionId, routeId));
         }
@@ -582,6 +636,7 @@
                     || !checkRouteIdIsValid(routeId, "transferToRoute")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onTransferToRoute,
                     MediaRoute2ProviderService.this, requestId, sessionId, routeId));
         }
@@ -594,6 +649,7 @@
             if (!checkSessionIdIsValid(sessionId, "setSessionVolume")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSetSessionVolume,
                     MediaRoute2ProviderService.this, requestId, sessionId, volume));
         }
@@ -606,6 +662,7 @@
             if (!checkSessionIdIsValid(sessionId, "releaseSession")) {
                 return;
             }
+            addRequestId(requestId);
             mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onReleaseSession,
                     MediaRoute2ProviderService.this, requestId, sessionId));
         }
diff --git a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java
index 9575581..638a842 100644
--- a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java
+++ b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java
@@ -593,11 +593,16 @@
 
         final int failureReason = REASON_REJECTED;
         final CountDownLatch onRequestFailedLatch = new CountDownLatch(1);
+        final CountDownLatch onRequestFailedSecondCallLatch = new CountDownLatch(1);
         addManagerCallback(new MediaRouter2Manager.Callback() {
             @Override
             public void onRequestFailed(int reason) {
                 if (reason == failureReason) {
-                    onRequestFailedLatch.countDown();
+                    if (onRequestFailedLatch.getCount() > 0) {
+                        onRequestFailedLatch.countDown();
+                    } else {
+                        onRequestFailedSecondCallLatch.countDown();
+                    }
                 }
             }
         });
@@ -609,6 +614,11 @@
         final long validRequestId = requestIds.get(0);
         instance.notifyRequestFailed(validRequestId, failureReason);
         assertTrue(onRequestFailedLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+
+        // Test calling notifyRequestFailed() multiple times with the same valid requestId.
+        // onRequestFailed() shouldn't be called since the requestId has already been handled.
+        instance.notifyRequestFailed(validRequestId, failureReason);
+        assertFalse(onRequestFailedSecondCallLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
     }
 
     @Test