Avoid activity leak via Companion callback

Test: invoke associate() API and ensure it still works
Change-Id: I9aedb043b4b1f0d77f076d8753cd60ff7c98a7d6
diff --git a/core/java/android/app/SystemServiceRegistry.java b/core/java/android/app/SystemServiceRegistry.java
index 40fe6af..fcf1931 100644
--- a/core/java/android/app/SystemServiceRegistry.java
+++ b/core/java/android/app/SystemServiceRegistry.java
@@ -95,7 +95,6 @@
 import android.os.BatteryManager;
 import android.os.BatteryStats;
 import android.os.Build;
-import android.os.Debug;
 import android.os.DropBoxManager;
 import android.os.HardwarePropertiesManager;
 import android.os.IBatteryPropertiesRegistrar;
@@ -118,8 +117,6 @@
 import android.os.storage.StorageManager;
 import android.print.IPrintManager;
 import android.print.PrintManager;
-import android.view.autofill.AutofillManager;
-import android.view.autofill.IAutoFillManager;
 import android.service.oemlock.IOemLockService;
 import android.service.oemlock.OemLockManager;
 import android.service.persistentdata.IPersistentDataBlockService;
@@ -136,6 +133,8 @@
 import android.view.WindowManagerImpl;
 import android.view.accessibility.AccessibilityManager;
 import android.view.accessibility.CaptioningManager;
+import android.view.autofill.AutofillManager;
+import android.view.autofill.IAutoFillManager;
 import android.view.inputmethod.InputMethodManager;
 import android.view.textclassifier.TextClassificationManager;
 import android.view.textservice.TextServicesManager;
@@ -660,7 +659,7 @@
                                 ServiceManager.getService(Context.COMPANION_DEVICE_SERVICE);
                         ICompanionDeviceManager service =
                                 ICompanionDeviceManager.Stub.asInterface(iBinder);
-                        return new CompanionDeviceManager(service, ctx);
+                        return new CompanionDeviceManager(service, ctx.getOuterContext());
                     }});
 
         registerService(Context.CONSUMER_IR_SERVICE, ConsumerIrManager.class,
diff --git a/core/java/android/companion/CompanionDeviceManager.java b/core/java/android/companion/CompanionDeviceManager.java
index fac9e13..4e70e3f 100644
--- a/core/java/android/companion/CompanionDeviceManager.java
+++ b/core/java/android/companion/CompanionDeviceManager.java
@@ -21,11 +21,14 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.app.Activity;
+import android.app.Application;
 import android.app.PendingIntent;
 import android.content.ComponentName;
 import android.content.Context;
 import android.content.IntentSender;
 import android.content.pm.PackageManager;
+import android.os.Bundle;
 import android.os.Handler;
 import android.os.RemoteException;
 import android.service.notification.NotificationListenerService;
@@ -137,26 +140,11 @@
         }
         checkNotNull(request, "Request cannot be null");
         checkNotNull(callback, "Callback cannot be null");
-        final Handler finalHandler = Handler.mainIfNull(handler);
         try {
             mService.associate(
                     request,
-                    //TODO implicit pointer to outer class -> =null onDestroy
-                    //TODO onStop if isFinishing -> stopScan
-                    new IFindDeviceCallback.Stub() {
-                        @Override
-                        public void onSuccess(PendingIntent launcher) {
-                            finalHandler.post(() -> {
-                                callback.onDeviceFound(launcher.getIntentSender());
-                            });
-                        }
-
-                        @Override
-                        public void onFailure(CharSequence reason) {
-                            finalHandler.post(() -> callback.onFailure(reason));
-                        }
-                    },
-                    mContext.getPackageName());
+                    new CallbackProxy(request, callback, Handler.mainIfNull(handler)),
+                    getCallingPackage());
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
@@ -175,7 +163,7 @@
             return Collections.emptyList();
         }
         try {
-            return mService.getAssociations(mContext.getPackageName(), mContext.getUserId());
+            return mService.getAssociations(getCallingPackage(), mContext.getUserId());
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
@@ -200,7 +188,7 @@
             return;
         }
         try {
-            mService.disassociate(deviceMacAddress, mContext.getPackageName());
+            mService.disassociate(deviceMacAddress, getCallingPackage());
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
@@ -263,4 +251,57 @@
         }
         return featurePresent;
     }
+
+    private Activity getActivity() {
+        return (Activity) mContext;
+    }
+
+    private String getCallingPackage() {
+        return mContext.getPackageName();
+    }
+
+    private class CallbackProxy extends IFindDeviceCallback.Stub
+            implements Application.ActivityLifecycleCallbacks {
+
+        private Callback mCallback;
+        private Handler mHandler;
+        private AssociationRequest mRequest;
+
+        private CallbackProxy(AssociationRequest request, Callback callback, Handler handler) {
+            mCallback = callback;
+            mHandler = handler;
+            mRequest = request;
+            getActivity().getApplication().registerActivityLifecycleCallbacks(this);
+        }
+
+        @Override
+        public void onSuccess(PendingIntent launcher) {
+            mHandler.post(() -> mCallback.onDeviceFound(launcher.getIntentSender()));
+        }
+
+        @Override
+        public void onFailure(CharSequence reason) {
+            mHandler.post(() -> mCallback.onFailure(reason));
+        }
+
+        @Override
+        public void onActivityDestroyed(Activity activity) {
+            try {
+                mService.stopScan(mRequest, this, getCallingPackage());
+            } catch (RemoteException e) {
+                e.rethrowFromSystemServer();
+            }
+            getActivity().getApplication().unregisterActivityLifecycleCallbacks(this);
+            mCallback = null;
+            mHandler = null;
+            mRequest = null;
+        }
+
+        @Override public void onActivityCreated(Activity activity, Bundle savedInstanceState) {}
+        @Override public void onActivityStarted(Activity activity) {}
+        @Override public void onActivityResumed(Activity activity) {}
+        @Override public void onActivityPaused(Activity activity) {}
+        @Override public void onActivityStopped(Activity activity) {}
+        @Override public void onActivitySaveInstanceState(Activity activity, Bundle outState) {}
+    }
 }
diff --git a/core/java/android/companion/ICompanionDeviceManager.aidl b/core/java/android/companion/ICompanionDeviceManager.aidl
index d395208..561342e 100644
--- a/core/java/android/companion/ICompanionDeviceManager.aidl
+++ b/core/java/android/companion/ICompanionDeviceManager.aidl
@@ -30,6 +30,9 @@
     void associate(in AssociationRequest request,
         in IFindDeviceCallback callback,
         in String callingPackage);
+    void stopScan(in AssociationRequest request,
+        in IFindDeviceCallback callback,
+        in String callingPackage);
 
     List<String> getAssociations(String callingPackage, int userId);
     void disassociate(String deviceMacAddress, String callingPackage);
diff --git a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
index 6093241..73f1705 100644
--- a/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
+++ b/services/companion/java/com/android/server/companion/CompanionDeviceManagerService.java
@@ -110,10 +110,15 @@
     private final CompanionDeviceManagerImpl mImpl;
     private final ConcurrentMap<Integer, AtomicFile> mUidToStorage = new ConcurrentHashMap<>();
     private IDeviceIdleController mIdleController;
-    private IFindDeviceCallback mFindDeviceCallback;
     private ServiceConnection mServiceConnection;
     private IAppOpsService mAppOpsManager;
 
+    private IFindDeviceCallback mFindDeviceCallback;
+    private AssociationRequest mRequest;
+    private String mCallingPackage;
+
+    private final Object mLock = new Object();
+
     public CompanionDeviceManagerService(Context context) {
         super(context);
         mImpl = new CompanionDeviceManagerImpl();
@@ -156,8 +161,12 @@
     }
 
     private void cleanup() {
-        mServiceConnection = unbind(mServiceConnection);
-        mFindDeviceCallback = unlinkToDeath(mFindDeviceCallback, this, 0);
+        synchronized (mLock) {
+            mServiceConnection = unbind(mServiceConnection);
+            mFindDeviceCallback = unlinkToDeath(mFindDeviceCallback, this, 0);
+            mRequest = null;
+            mCallingPackage = null;
+        }
     }
 
     /**
@@ -222,6 +231,17 @@
         }
 
         @Override
+        public void stopScan(AssociationRequest request,
+                IFindDeviceCallback callback,
+                String callingPackage) {
+            if (Objects.equals(request, mRequest)
+                    && Objects.equals(callback, mFindDeviceCallback)
+                    && Objects.equals(callingPackage, mCallingPackage)) {
+                cleanup();
+            }
+        }
+
+        @Override
         public List<String> getAssociations(String callingPackage, int userId)
                 throws RemoteException {
             checkCallerIsSystemOr(callingPackage, userId);
@@ -340,7 +360,11 @@
                             "onServiceConnected(name = " + name + ", service = "
                                     + service + ")");
                 }
+
                 mFindDeviceCallback = findDeviceCallback;
+                mRequest = request;
+                mCallingPackage = callingPackage;
+
                 try {
                     mFindDeviceCallback.asBinder().linkToDeath(
                             CompanionDeviceManagerService.this, 0);
@@ -348,6 +372,7 @@
                     cleanup();
                     return;
                 }
+
                 try {
                     ICompanionDeviceDiscoveryService.Stub
                             .asInterface(service)