Merge "Allow caller holding BIND_CONTENT_SUGGESTIONS_SERVICE to make suggestion calls" into qt-dev
diff --git a/services/contentsuggestions/java/com/android/server/contentsuggestions/ContentSuggestionsManagerService.java b/services/contentsuggestions/java/com/android/server/contentsuggestions/ContentSuggestionsManagerService.java
index 55a0621..ecea251c 100644
--- a/services/contentsuggestions/java/com/android/server/contentsuggestions/ContentSuggestionsManagerService.java
+++ b/services/contentsuggestions/java/com/android/server/contentsuggestions/ContentSuggestionsManagerService.java
@@ -16,7 +16,9 @@
 
 package com.android.server.contentsuggestions;
 
+import static android.Manifest.permission.BIND_CONTENT_SUGGESTIONS_SERVICE;
 import static android.Manifest.permission.MANAGE_CONTENT_SUGGESTIONS;
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
@@ -92,16 +94,11 @@
         return MAX_TEMP_SERVICE_DURATION_MS;
     }
 
-    private boolean isCallerRecents(int userId) {
-        if (mServiceNameResolver.isTemporary(userId)) {
-            // If a temporary service is set then skip the recents check
-            return true;
-        }
-        return mActivityTaskManagerInternal.isCallerRecents(Binder.getCallingUid());
-    }
-
-    private void enforceCallerIsRecents(int userId, String func) {
-        if (isCallerRecents(userId)) {
+    private void enforceCaller(int userId, String func) {
+        Context ctx = getContext();
+        if (ctx.checkCallingPermission(BIND_CONTENT_SUGGESTIONS_SERVICE) == PERMISSION_GRANTED
+                || mServiceNameResolver.isTemporary(userId)
+                || mActivityTaskManagerInternal.isCallerRecents(Binder.getCallingUid())) {
             return;
         }
 
@@ -122,7 +119,7 @@
             if (imageContextRequestExtras == null) {
                 throw new IllegalArgumentException("Expected non-null imageContextRequestExtras");
             }
-            enforceCallerIsRecents(UserHandle.getCallingUserId(), "provideContextImage");
+            enforceCaller(UserHandle.getCallingUserId(), "provideContextImage");
 
             synchronized (mLock) {
                 final ContentSuggestionsPerUserService service = getServiceForUserLocked(userId);
@@ -141,7 +138,7 @@
                 int userId,
                 @NonNull SelectionsRequest selectionsRequest,
                 @NonNull ISelectionsCallback selectionsCallback) {
-            enforceCallerIsRecents(UserHandle.getCallingUserId(), "suggestContentSelections");
+            enforceCaller(UserHandle.getCallingUserId(), "suggestContentSelections");
 
             synchronized (mLock) {
                 final ContentSuggestionsPerUserService service = getServiceForUserLocked(userId);
@@ -160,7 +157,7 @@
                 int userId,
                 @NonNull ClassificationsRequest classificationsRequest,
                 @NonNull IClassificationsCallback callback) {
-            enforceCallerIsRecents(UserHandle.getCallingUserId(), "classifyContentSelections");
+            enforceCaller(UserHandle.getCallingUserId(), "classifyContentSelections");
 
             synchronized (mLock) {
                 final ContentSuggestionsPerUserService service = getServiceForUserLocked(userId);
@@ -177,7 +174,7 @@
         @Override
         public void notifyInteraction(
                 int userId, @NonNull String requestId, @NonNull Bundle bundle) {
-            enforceCallerIsRecents(UserHandle.getCallingUserId(), "notifyInteraction");
+            enforceCaller(UserHandle.getCallingUserId(), "notifyInteraction");
 
             synchronized (mLock) {
                 final ContentSuggestionsPerUserService service = getServiceForUserLocked(userId);
@@ -194,7 +191,7 @@
         @Override
         public void isEnabled(int userId, @NonNull IResultReceiver receiver)
                 throws RemoteException {
-            enforceCallerIsRecents(UserHandle.getCallingUserId(), "isEnabled");
+            enforceCaller(UserHandle.getCallingUserId(), "isEnabled");
 
             boolean isDisabled;
             synchronized (mLock) {