Merge "Prevent abuse of MediaRoute2ProviderService#notifyRequestFailed()" into rvc-dev
diff --git a/media/java/android/media/MediaRoute2ProviderService.java b/media/java/android/media/MediaRoute2ProviderService.java
index 05c6e3a..908fd82 100644
--- a/media/java/android/media/MediaRoute2ProviderService.java
+++ b/media/java/android/media/MediaRoute2ProviderService.java
@@ -40,8 +40,10 @@
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Deque;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -132,15 +134,21 @@
@Retention(RetentionPolicy.SOURCE)
public @interface Reason {}
+ private static final int MAX_REQUEST_IDS_SIZE = 500;
+
private final Handler mHandler;
private final Object mSessionLock = new Object();
+ private final Object mRequestIdsLock = new Object();
private final AtomicBoolean mStatePublishScheduled = new AtomicBoolean(false);
private MediaRoute2ProviderServiceStub mStub;
private IMediaRoute2ProviderServiceCallback mRemoteCallback;
private volatile MediaRoute2ProviderInfo mProviderInfo;
+ @GuardedBy("mRequestIdsLock")
+ private final Deque<Long> mRequestIds = new ArrayDeque<>(MAX_REQUEST_IDS_SIZE);
+
@GuardedBy("mSessionLock")
- private ArrayMap<String, RoutingSessionInfo> mSessionInfo = new ArrayMap<>();
+ private final ArrayMap<String, RoutingSessionInfo> mSessionInfo = new ArrayMap<>();
public MediaRoute2ProviderService() {
mHandler = new Handler(Looper.getMainLooper());
@@ -230,6 +238,11 @@
@NonNull RoutingSessionInfo sessionInfo) {
Objects.requireNonNull(sessionInfo, "sessionInfo must not be null");
+ if (requestId != REQUEST_ID_NONE && !removeRequestId(requestId)) {
+ Log.w(TAG, "notifySessionCreated: The requestId doesn't exist. requestId=" + requestId);
+ return;
+ }
+
String sessionId = sessionInfo.getId();
synchronized (mSessionLock) {
if (mSessionInfo.containsKey(sessionId)) {
@@ -322,6 +335,13 @@
if (mRemoteCallback == null) {
return;
}
+
+ if (!removeRequestId(requestId)) {
+ Log.w(TAG, "notifyRequestFailed: The requestId doesn't exist. requestId="
+ + requestId);
+ return;
+ }
+
try {
mRemoteCallback.notifyRequestFailed(requestId, reason);
} catch (RemoteException ex) {
@@ -469,6 +489,36 @@
}
}
+ /**
+ * Records a received {@code requestId}, keeping at most {@link #MAX_REQUEST_IDS_SIZE} IDs.
+ * Once that limit is hit, the oldest recorded ID is evicted before appending (FIFO).
+ */
+ private void addRequestId(long requestId) {
+ synchronized (mRequestIdsLock) {
+ if (MAX_REQUEST_IDS_SIZE <= mRequestIds.size()) {
+ mRequestIds.pollFirst(); // Deque is non-empty here; drop the oldest ID to make room.
+ }
+ mRequestIds.offerLast(requestId);
+ }
+ }
+
+ /**
+ * Removes the given {@code requestId} from the received request ID list.
+ * <p>
+ * Returns {@code false} when the list doesn't contain the {@code requestId}, i.e. when:
+ * <ul>
+ * <li>this service has never received a request with the requestId,</li>
+ * <li>{@link #notifyRequestFailed} or {@link #notifySessionCreated} was already called
+ * for the requestId, or</li>
+ * <li>the requestId was evicted after {@link #MAX_REQUEST_IDS_SIZE} newer requests.</li>
+ * </ul>
+ */
+ private boolean removeRequestId(long requestId) {
+ synchronized (mRequestIdsLock) {
+ return mRequestIds.removeFirstOccurrence(requestId);
+ }
+ }
+
final class MediaRoute2ProviderServiceStub extends IMediaRoute2ProviderService.Stub {
MediaRoute2ProviderServiceStub() { }
@@ -529,6 +579,7 @@
if (!checkRouteIdIsValid(routeId, "setRouteVolume")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSetRouteVolume,
MediaRoute2ProviderService.this, requestId, routeId, volume));
}
@@ -542,6 +593,7 @@
if (!checkRouteIdIsValid(routeId, "requestCreateSession")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onCreateSession,
MediaRoute2ProviderService.this, requestId, packageName, routeId,
requestCreateSession));
@@ -556,6 +608,7 @@
|| !checkRouteIdIsValid(routeId, "selectRoute")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSelectRoute,
MediaRoute2ProviderService.this, requestId, sessionId, routeId));
}
@@ -569,6 +622,7 @@
|| !checkRouteIdIsValid(routeId, "deselectRoute")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onDeselectRoute,
MediaRoute2ProviderService.this, requestId, sessionId, routeId));
}
@@ -582,6 +636,7 @@
|| !checkRouteIdIsValid(routeId, "transferToRoute")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onTransferToRoute,
MediaRoute2ProviderService.this, requestId, sessionId, routeId));
}
@@ -594,6 +649,7 @@
if (!checkSessionIdIsValid(sessionId, "setSessionVolume")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSetSessionVolume,
MediaRoute2ProviderService.this, requestId, sessionId, volume));
}
@@ -606,6 +662,7 @@
if (!checkSessionIdIsValid(sessionId, "releaseSession")) {
return;
}
+ addRequestId(requestId);
mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onReleaseSession,
MediaRoute2ProviderService.this, requestId, sessionId));
}
diff --git a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java
index 9575581..638a842 100644
--- a/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java
+++ b/media/tests/MediaRouter/src/com/android/mediaroutertest/MediaRouter2ManagerTest.java
@@ -593,11 +593,16 @@
final int failureReason = REASON_REJECTED;
final CountDownLatch onRequestFailedLatch = new CountDownLatch(1);
+ final CountDownLatch onRequestFailedSecondCallLatch = new CountDownLatch(1);
addManagerCallback(new MediaRouter2Manager.Callback() {
@Override
public void onRequestFailed(int reason) {
if (reason == failureReason) {
- onRequestFailedLatch.countDown();
+ if (onRequestFailedLatch.getCount() > 0) {
+ onRequestFailedLatch.countDown();
+ } else {
+ onRequestFailedSecondCallLatch.countDown();
+ }
}
}
});
@@ -609,6 +614,11 @@
final long validRequestId = requestIds.get(0);
instance.notifyRequestFailed(validRequestId, failureReason);
assertTrue(onRequestFailedLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
+
+ // Test calling notifyRequestFailed() multiple times with the same valid requestId.
+ // onRequestFailed() shouldn't be called since the requestId has been already handled.
+ instance.notifyRequestFailed(validRequestId, failureReason);
+ assertFalse(onRequestFailedSecondCallLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
}
@Test