Merge "Add AppPredictionServiceResolverComparator" into qt-dev
am: ba3b157e32

Change-Id: I385cc811982f84186748b8565269fc40b4d6779a
diff --git a/core/java/com/android/internal/app/AbstractResolverComparator.java b/core/java/com/android/internal/app/AbstractResolverComparator.java
index e091aac..b7276a0 100644
--- a/core/java/com/android/internal/app/AbstractResolverComparator.java
+++ b/core/java/com/android/internal/app/AbstractResolverComparator.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package com.android.internal.app;
 
 import android.app.usage.UsageStatsManager;
@@ -20,7 +36,7 @@
 
     private static final int NUM_OF_TOP_ANNOTATIONS_TO_USE = 3;
 
-    protected AfterCompute mAfterCompute;
+    private AfterCompute mAfterCompute;
     protected final PackageManager mPm;
     protected final UsageStatsManager mUsm;
     protected String[] mAnnotations;
@@ -72,6 +88,13 @@
         mAfterCompute = afterCompute;
     }
 
+    /** Invokes the registered {@link AfterCompute} callback, if one is set. */
+    protected final void afterCompute() {
+        // Read the field once so the null check and the call use the same callback.
+        final AfterCompute callback = mAfterCompute;
+        if (callback != null) callback.afterCompute();
+    }
+
     @Override
     public final int compare(ResolvedComponentInfo lhsp, ResolvedComponentInfo rhsp) {
         final ResolveInfo lhs = lhsp.getResolveInfoAt(0);
diff --git a/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java b/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
new file mode 100644
index 0000000..cb44c67
--- /dev/null
+++ b/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.internal.app;
+
+import static android.app.prediction.AppTargetEvent.ACTION_LAUNCH;
+
+import android.app.prediction.AppPredictor;
+import android.app.prediction.AppTarget;
+import android.app.prediction.AppTargetEvent;
+import android.app.prediction.AppTargetId;
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.pm.ResolveInfo;
+import android.os.UserHandle;
+import android.util.Log;
+
+import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Uses an {@link AppPredictor} to sort Resolver targets.
+ *
+ * <p>{@link #compute} asks the predictor to order the candidate targets and caches each
+ * component's position, which {@code compare(ResolveInfo, ResolveInfo)} and {@link #getScore}
+ * read. Launches are reported back to the predictor through {@link #updateModel}.
+ */
+class AppPredictionServiceResolverComparator extends AbstractResolverComparator {
+
+    private static final String TAG = "APSResolverComparator";
+
+    private final AppPredictor mAppPredictor;
+    private final Context mContext;
+    // Position of each component in the latest predictor result (0 = first). Written by the
+    // sortTargets callback, which runs on the main executor.
+    private final Map<ComponentName, Integer> mTargetRanks = new HashMap<>();
+    private final UserHandle mUser;
+
+    AppPredictionServiceResolverComparator(
+                Context context, Intent intent, AppPredictor appPredictor, UserHandle user) {
+        super(context, intent);
+        mContext = context;
+        mAppPredictor = appPredictor;
+        mUser = user;
+    }
+
+    /**
+     * Compares targets by predicted rank (a lower rank compares as less). A target without a
+     * rank compares as less than a ranked one; two unranked targets compare as equal.
+     * NOTE(review): in an ascending sort this places unranked targets first; confirm intended.
+     */
+    @Override
+    int compare(ResolveInfo lhs, ResolveInfo rhs) {
+        Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName,
+                lhs.activityInfo.name));
+        Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName,
+                rhs.activityInfo.name));
+        if (lhsRank == null && rhsRank == null) {
+            return 0;
+        } else if (lhsRank == null) {
+            return -1;
+        } else if (rhsRank == null) {
+            return 1;
+        }
+        return Integer.compare(lhsRank, rhsRank);
+    }
+
+    /**
+     * Asks the {@link AppPredictor} to sort {@code targets}. When the sorted list arrives on the
+     * main executor, the cached ranks are replaced with its order and the after-compute callback
+     * is invoked.
+     */
+    @Override
+    void compute(List<ResolvedComponentInfo> targets) {
+        List<AppTarget> appTargets = new ArrayList<>(targets.size());
+        for (ResolvedComponentInfo target : targets) {
+            appTargets.add(toAppTarget(target.name));
+        }
+        mAppPredictor.sortTargets(appTargets, mContext.getMainExecutor(),
+                sortedAppTargets -> {
+                    // Forget ranks from any earlier compute() so stale components neither
+                    // affect ordering nor inflate the normalization in getScore().
+                    mTargetRanks.clear();
+                    for (int i = 0; i < sortedAppTargets.size(); i++) {
+                        AppTarget sorted = sortedAppTargets.get(i);
+                        mTargetRanks.put(
+                                new ComponentName(sorted.getPackageName(), sorted.getClassName()),
+                                i);
+                    }
+                    afterCompute();
+                });
+    }
+
+    /**
+     * Converts a predicted rank into a score of {@code 1 - rank / (0 + 1 + ... + (n - 1))},
+     * where {@code n} is the number of ranked components. Unknown components score 0.
+     */
+    @Override
+    float getScore(ComponentName name) {
+        Integer rank = mTargetRanks.get(name);
+        if (rank == null) {
+            Log.w(TAG, "Score requested for unknown component.");
+            return 0f;
+        }
+        int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2;
+        if (consecutiveSumOfRanks == 0) {
+            // Only one ranked component: the formula below would divide by zero (giving NaN).
+            return 1.0f;
+        }
+        return 1.0f - (((float) rank) / consecutiveSumOfRanks);
+    }
+
+    /** Reports a launch of {@code componentName} to the predictor as an ACTION_LAUNCH event. */
+    @Override
+    void updateModel(ComponentName componentName) {
+        // Use the same AppTargetId scheme as compute() so the launch event and the sortTargets
+        // request identify the component by the same id.
+        mAppPredictor.notifyAppTargetEvent(
+                new AppTargetEvent.Builder(toAppTarget(componentName), ACTION_LAUNCH).build());
+    }
+
+    @Override
+    void destroy() {
+        // Do nothing. App Predictor destruction is handled by caller.
+    }
+
+    /** Builds the {@link AppTarget} that identifies {@code name} to the predictor. */
+    private AppTarget toAppTarget(ComponentName name) {
+        return new AppTarget.Builder(new AppTargetId(name.flattenToString()),
+                name.getPackageName(), mUser)
+                .setClassName(name.getClassName())
+                .build();
+    }
+}
diff --git a/core/java/com/android/internal/app/ChooserActivity.java b/core/java/com/android/internal/app/ChooserActivity.java
index 54338bf..59e867f 100644
--- a/core/java/com/android/internal/app/ChooserActivity.java
+++ b/core/java/com/android/internal/app/ChooserActivity.java
@@ -150,6 +150,7 @@
      */
     // TODO(b/123089490): Replace with system flag
     private static final boolean USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS = false;
+    private static final boolean USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES = false;
     // TODO(b/123088566) Share these in a better way.
     private static final String APP_PREDICTION_SHARE_UI_SURFACE = "share";
     public static final String LAUNCH_LOCATON_DIRECT_SHARE = "direct_share";
@@ -1387,6 +1388,15 @@
         return USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS ? getAppPredictor() : null;
     }
 
+    /**
+     * Returns the app predictor used to rank share activities, or null when
+     * USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES is off or getAppPredictor() returns null.
+     */
+    @Nullable
+    private AppPredictor getAppPredictorForShareActivitesIfEnabled() {
+        return !USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES ? null : getAppPredictor();
+    }
+
     void onRefinementResult(TargetInfo selectedTarget, Intent matchingIntent) {
         if (mRefinementResultReceiver != null) {
             mRefinementResultReceiver.destroy();
@@ -1491,8 +1501,10 @@
                 PackageManager pm,
                 Intent targetIntent,
                 String referrerPackageName,
-                int launchedFromUid) {
-            super(context, pm, targetIntent, referrerPackageName, launchedFromUid);
+                int launchedFromUid,
+                AbstractResolverComparator resolverComparator) {
+            super(context, pm, targetIntent, referrerPackageName, launchedFromUid,
+                    resolverComparator);
         }
 
         @Override
@@ -1520,13 +1532,24 @@
 
     @VisibleForTesting
     protected ResolverListController createListController() {
+        // Non-null only when USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES is set.
+        final AppPredictor appPredictor = getAppPredictorForShareActivitesIfEnabled();
+        // Rank with the AppPredictionService when a predictor is available; otherwise fall back
+        // to the ResolverRankerService-based comparator.
+        final AbstractResolverComparator resolverComparator =
+                appPredictor == null
+                        ? new ResolverRankerServiceResolverComparator(
+                                this, getTargetIntent(), getReferrerPackageName(), null)
+                        : new AppPredictionServiceResolverComparator(
+                                this, getTargetIntent(), appPredictor, getUser());
+
         return new ChooserListController(
                 this,
                 mPm,
                 getTargetIntent(),
                 getReferrerPackageName(),
-                mLaunchedFromUid
-                );
+                mLaunchedFromUid,
+                resolverComparator);
     }
 
     @VisibleForTesting
diff --git a/core/java/com/android/internal/app/ResolverListController.java b/core/java/com/android/internal/app/ResolverListController.java
index a3cfa87..5f92cdd 100644
--- a/core/java/com/android/internal/app/ResolverListController.java
+++ b/core/java/com/android/internal/app/ResolverListController.java
@@ -63,14 +63,24 @@
             Intent targetIntent,
             String referrerPackage,
             int launchedFromUid) {
+        this(context, pm, targetIntent, referrerPackage, launchedFromUid,
+                    new ResolverRankerServiceResolverComparator(
+                        context, targetIntent, referrerPackage, null));
+    }
+
+    /**
+     * Creates a controller that ranks targets with {@code resolverComparator}; the constructor
+     * without that parameter supplies a ResolverRankerServiceResolverComparator.
+     */
+    public ResolverListController(Context context, PackageManager pm, Intent targetIntent,
+            String referrerPackage, int launchedFromUid,
+            AbstractResolverComparator resolverComparator) {
         mContext = context;
         mpm = pm;
         mLaunchedFromUid = launchedFromUid;
         mTargetIntent = targetIntent;
         mReferrerPackage = referrerPackage;
-        mResolverComparator =
-                new ResolverRankerServiceResolverComparator(
-                    mContext, mTargetIntent, mReferrerPackage, null);
+        mResolverComparator = resolverComparator;
     }
 
     @VisibleForTesting
diff --git a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
index 9bf4f01..726b186 100644
--- a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
+++ b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
@@ -126,7 +126,7 @@
                             Log.e(TAG, "Receiving null prediction results.");
                         }
                         mHandler.removeMessages(RESOLVER_RANKER_RESULT_TIMEOUT);
-                        mAfterCompute.afterCompute();
+                        afterCompute();
                     }
                     break;
 
@@ -135,7 +135,7 @@
                         Log.d(TAG, "RESOLVER_RANKER_RESULT_TIMEOUT; unbinding services");
                     }
                     mHandler.removeMessages(RESOLVER_RANKER_SERVICE_RESULT);
-                    mAfterCompute.afterCompute();
+                    afterCompute();
                     break;
 
                 default:
@@ -149,7 +149,6 @@
         super(context, intent);
         mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
         mReferrerPackage = referrerPackage;
-        mAfterCompute = afterCompute;
         mContext = context;
 
         mCurrentTime = System.currentTimeMillis();
@@ -157,6 +156,7 @@
         mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
         mAction = intent.getAction();
         mRankerServiceName = new ComponentName(mContext, this.getClass());
+        setCallBack(afterCompute);
     }
 
     // compute features for each target according to usage stats of targets.
@@ -328,9 +328,7 @@
             mContext.unbindService(mConnection);
             mConnection.destroy();
         }
-        if (mAfterCompute != null) {
-            mAfterCompute.afterCompute();
-        }
+        afterCompute();
         if (DEBUG) {
             Log.d(TAG, "Unbinded Resolver Ranker.");
         }
@@ -513,9 +511,7 @@
                 Log.e(TAG, "Error in Predict: " + e);
             }
         }
-        if (mAfterCompute != null) {
-            mAfterCompute.afterCompute();
-        }
+        afterCompute();
     }
 
     // adds select prob as the default values, according to a pre-trained Logistic Regression model.